infrahub-server 1.4.9__py3-none-any.whl → 1.4.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infrahub/api/oauth2.py +13 -19
- infrahub/api/oidc.py +15 -21
- infrahub/artifacts/models.py +2 -1
- infrahub/auth.py +137 -3
- infrahub/cli/db.py +24 -0
- infrahub/cli/db_commands/clean_duplicate_schema_fields.py +212 -0
- infrahub/computed_attribute/tasks.py +1 -1
- infrahub/core/changelog/models.py +2 -2
- infrahub/core/diff/query/artifact.py +12 -9
- infrahub/core/ipam/utilization.py +1 -1
- infrahub/core/manager.py +6 -3
- infrahub/core/node/__init__.py +3 -1
- infrahub/core/node/constraints/attribute_uniqueness.py +3 -1
- infrahub/core/node/create.py +12 -3
- infrahub/core/registry.py +2 -2
- infrahub/core/relationship/constraints/count.py +1 -1
- infrahub/core/relationship/model.py +1 -1
- infrahub/core/schema/definitions/internal.py +4 -0
- infrahub/core/schema/manager.py +19 -1
- infrahub/core/schema/node_schema.py +4 -2
- infrahub/core/schema/schema_branch.py +8 -0
- infrahub/core/validators/determiner.py +12 -1
- infrahub/core/validators/relationship/peer.py +1 -1
- infrahub/core/validators/tasks.py +1 -1
- infrahub/generators/tasks.py +3 -7
- infrahub/git/integrator.py +1 -1
- infrahub/git/models.py +2 -1
- infrahub/git/repository.py +22 -5
- infrahub/git/tasks.py +14 -8
- infrahub/git/utils.py +123 -1
- infrahub/graphql/analyzer.py +1 -1
- infrahub/graphql/mutations/main.py +3 -3
- infrahub/graphql/mutations/schema.py +5 -5
- infrahub/message_bus/types.py +2 -1
- infrahub/middleware.py +26 -1
- infrahub/proposed_change/tasks.py +11 -12
- infrahub/server.py +12 -3
- infrahub/workers/dependencies.py +8 -1
- {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/METADATA +17 -17
- {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/RECORD +46 -45
- infrahub_testcontainers/container.py +1 -1
- infrahub_testcontainers/docker-compose-cluster.test.yml +1 -1
- infrahub_testcontainers/docker-compose.test.yml +1 -1
- {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/LICENSE.txt +0 -0
- {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/WHEEL +0 -0
- {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/entry_points.txt +0 -0
infrahub/api/oauth2.py
CHANGED
|
@@ -11,14 +11,16 @@ from opentelemetry import trace
|
|
|
11
11
|
|
|
12
12
|
from infrahub import config, models
|
|
13
13
|
from infrahub.api.dependencies import get_db
|
|
14
|
-
from infrahub.auth import
|
|
15
|
-
|
|
14
|
+
from infrahub.auth import (
|
|
15
|
+
get_groups_from_provider,
|
|
16
|
+
signin_sso_account,
|
|
17
|
+
validate_auth_response,
|
|
18
|
+
)
|
|
19
|
+
from infrahub.exceptions import ProcessingError
|
|
16
20
|
from infrahub.log import get_logger
|
|
17
21
|
from infrahub.message_bus.types import KVTTL
|
|
18
22
|
|
|
19
23
|
if TYPE_CHECKING:
|
|
20
|
-
import httpx
|
|
21
|
-
|
|
22
24
|
from infrahub.database import InfrahubDatabase
|
|
23
25
|
from infrahub.services import InfrahubServices
|
|
24
26
|
|
|
@@ -95,7 +97,7 @@ async def token(
|
|
|
95
97
|
}
|
|
96
98
|
|
|
97
99
|
token_response = await service.http.post(provider.token_url, data=token_data)
|
|
98
|
-
|
|
100
|
+
validate_auth_response(response=token_response, provider_type="OAuth 2.0")
|
|
99
101
|
|
|
100
102
|
with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
|
|
101
103
|
span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
|
|
@@ -107,12 +109,17 @@ async def token(
|
|
|
107
109
|
else:
|
|
108
110
|
userinfo_response = await service.http.post(provider.userinfo_url, headers=headers)
|
|
109
111
|
|
|
110
|
-
|
|
112
|
+
validate_auth_response(response=userinfo_response, provider_type="OAuth 2.0")
|
|
111
113
|
user_info = userinfo_response.json()
|
|
112
114
|
sso_groups = user_info.get("groups", []) or await get_groups_from_provider(
|
|
113
115
|
provider=provider, service=service, payload=payload, user_info=user_info
|
|
114
116
|
)
|
|
115
117
|
|
|
118
|
+
log.info(
|
|
119
|
+
"SSO user authenticated",
|
|
120
|
+
body={"user_name": user_info.get("name"), "groups": sso_groups},
|
|
121
|
+
)
|
|
122
|
+
|
|
116
123
|
if not sso_groups and config.SETTINGS.security.sso_user_default_group:
|
|
117
124
|
sso_groups = [config.SETTINGS.security.sso_user_default_group]
|
|
118
125
|
|
|
@@ -134,16 +141,3 @@ async def token(
|
|
|
134
141
|
return models.UserTokenWithUrl(
|
|
135
142
|
access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=stored_final_url
|
|
136
143
|
)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def _validate_response(response: httpx.Response) -> None:
|
|
140
|
-
if 200 <= response.status_code <= 299:
|
|
141
|
-
return
|
|
142
|
-
|
|
143
|
-
log.error(
|
|
144
|
-
"Invalid response from the OAuth provider",
|
|
145
|
-
url=response.url,
|
|
146
|
-
status_code=response.status_code,
|
|
147
|
-
body=response.json(),
|
|
148
|
-
)
|
|
149
|
-
raise GatewayError(message="Invalid response from Authentication provider")
|
infrahub/api/oidc.py
CHANGED
|
@@ -13,14 +13,16 @@ from pydantic import BaseModel, HttpUrl
|
|
|
13
13
|
|
|
14
14
|
from infrahub import config, models
|
|
15
15
|
from infrahub.api.dependencies import get_db
|
|
16
|
-
from infrahub.auth import
|
|
17
|
-
|
|
16
|
+
from infrahub.auth import (
|
|
17
|
+
get_groups_from_provider,
|
|
18
|
+
signin_sso_account,
|
|
19
|
+
validate_auth_response,
|
|
20
|
+
)
|
|
21
|
+
from infrahub.exceptions import ProcessingError
|
|
18
22
|
from infrahub.log import get_logger
|
|
19
23
|
from infrahub.message_bus.types import KVTTL
|
|
20
24
|
|
|
21
25
|
if TYPE_CHECKING:
|
|
22
|
-
import httpx
|
|
23
|
-
|
|
24
26
|
from infrahub.database import InfrahubDatabase
|
|
25
27
|
from infrahub.services import InfrahubServices
|
|
26
28
|
|
|
@@ -69,7 +71,7 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
|
|
|
69
71
|
service: InfrahubServices = request.app.state.service
|
|
70
72
|
|
|
71
73
|
response = await service.http.get(url=provider.discovery_url)
|
|
72
|
-
|
|
74
|
+
validate_auth_response(response=response, provider_type="OIDC")
|
|
73
75
|
oidc_config = OIDCDiscoveryConfig(**response.json())
|
|
74
76
|
|
|
75
77
|
with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
|
|
@@ -129,12 +131,12 @@ async def token(
|
|
|
129
131
|
}
|
|
130
132
|
|
|
131
133
|
discovery_response = await service.http.get(url=provider.discovery_url)
|
|
132
|
-
|
|
134
|
+
validate_auth_response(response=discovery_response, provider_type="OIDC")
|
|
133
135
|
|
|
134
136
|
oidc_config = OIDCDiscoveryConfig(**discovery_response.json())
|
|
135
137
|
|
|
136
138
|
token_response = await service.http.post(str(oidc_config.token_endpoint), data=token_data)
|
|
137
|
-
|
|
139
|
+
validate_auth_response(response=token_response, provider_type="OIDC")
|
|
138
140
|
|
|
139
141
|
with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
|
|
140
142
|
span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
|
|
@@ -147,7 +149,7 @@ async def token(
|
|
|
147
149
|
else:
|
|
148
150
|
userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)
|
|
149
151
|
|
|
150
|
-
|
|
152
|
+
validate_auth_response(response=userinfo_response, provider_type="OIDC")
|
|
151
153
|
user_info: dict[str, Any] = userinfo_response.json()
|
|
152
154
|
sso_groups = (
|
|
153
155
|
user_info.get("groups")
|
|
@@ -157,6 +159,11 @@ async def token(
|
|
|
157
159
|
or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info)
|
|
158
160
|
)
|
|
159
161
|
|
|
162
|
+
log.info(
|
|
163
|
+
"SSO user authenticated",
|
|
164
|
+
body={"user_name": user_info.get("name"), "groups": sso_groups},
|
|
165
|
+
)
|
|
166
|
+
|
|
160
167
|
if not sso_groups and config.SETTINGS.security.sso_user_default_group:
|
|
161
168
|
sso_groups = [config.SETTINGS.security.sso_user_default_group]
|
|
162
169
|
|
|
@@ -180,19 +187,6 @@ async def token(
|
|
|
180
187
|
)
|
|
181
188
|
|
|
182
189
|
|
|
183
|
-
def _validate_response(response: httpx.Response) -> None:
|
|
184
|
-
if 200 <= response.status_code <= 299:
|
|
185
|
-
return
|
|
186
|
-
|
|
187
|
-
log.error(
|
|
188
|
-
"Invalid response from the OIDC provider",
|
|
189
|
-
url=response.url,
|
|
190
|
-
status_code=response.status_code,
|
|
191
|
-
body=response.json(),
|
|
192
|
-
)
|
|
193
|
-
raise GatewayError(message="Invalid response from Authentication provider")
|
|
194
|
-
|
|
195
|
-
|
|
196
190
|
async def _get_id_token_groups(
|
|
197
191
|
oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
|
|
198
192
|
) -> list[str]:
|
infrahub/artifacts/models.py
CHANGED
|
@@ -25,7 +25,8 @@ class CheckArtifactCreate(BaseModel):
|
|
|
25
25
|
target_kind: str = Field(..., description="The kind of the target object for this artifact")
|
|
26
26
|
target_name: str = Field(..., description="Name of the artifact target")
|
|
27
27
|
artifact_id: str | None = Field(default=None, description="The id of the artifact if it previously existed")
|
|
28
|
-
query: str = Field(..., description="The name of the query to use when collecting data")
|
|
28
|
+
query: str = Field(..., description="The name of the query to use when collecting data") # Deprecated
|
|
29
|
+
query_id: str = Field(..., description="The id of the query to use when collecting data")
|
|
29
30
|
timeout: int = Field(..., description="Timeout for requests used to generate this artifact")
|
|
30
31
|
variables: dict = Field(..., description="Input variables when generating the artifact")
|
|
31
32
|
validator_id: str = Field(..., description="The ID of the validator")
|
infrahub/auth.py
CHANGED
|
@@ -3,27 +3,37 @@ from __future__ import annotations
|
|
|
3
3
|
import uuid
|
|
4
4
|
from datetime import datetime, timedelta, timezone
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
import bcrypt
|
|
9
9
|
import jwt
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
12
|
from infrahub import config, models
|
|
13
|
-
from infrahub.config import
|
|
13
|
+
from infrahub.config import (
|
|
14
|
+
SecurityOAuth2Google,
|
|
15
|
+
SecurityOAuth2Settings,
|
|
16
|
+
SecurityOIDCGoogle,
|
|
17
|
+
SecurityOIDCSettings,
|
|
18
|
+
)
|
|
14
19
|
from infrahub.core.account import validate_token
|
|
15
20
|
from infrahub.core.constants import AccountStatus, InfrahubKind
|
|
16
21
|
from infrahub.core.manager import NodeManager
|
|
17
22
|
from infrahub.core.node import Node
|
|
18
23
|
from infrahub.core.protocols import CoreAccount, CoreAccountGroup
|
|
19
24
|
from infrahub.core.registry import registry
|
|
20
|
-
from infrahub.exceptions import AuthorizationError, NodeNotFoundError
|
|
25
|
+
from infrahub.exceptions import AuthorizationError, GatewayError, NodeNotFoundError
|
|
26
|
+
from infrahub.log import get_logger
|
|
21
27
|
|
|
22
28
|
if TYPE_CHECKING:
|
|
29
|
+
import httpx
|
|
30
|
+
|
|
23
31
|
from infrahub.core.protocols import CoreGenericAccount
|
|
24
32
|
from infrahub.database import InfrahubDatabase
|
|
25
33
|
from infrahub.services import InfrahubServices
|
|
26
34
|
|
|
35
|
+
log = get_logger()
|
|
36
|
+
|
|
27
37
|
|
|
28
38
|
class AuthType(str, Enum):
|
|
29
39
|
NONE = "none"
|
|
@@ -256,3 +266,127 @@ async def get_groups_from_provider(
|
|
|
256
266
|
return [membership["groupKey"]["id"] for membership in group_memberships["memberships"]]
|
|
257
267
|
|
|
258
268
|
return []
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def safe_get_response_body(response: httpx.Response, raise_error_on_empty_body: bool = True) -> str | dict[str, Any]:
    """Safely extract the body of an HTTP response returned by an authentication provider.

    The body is decoded as JSON when possible; otherwise the raw text is returned. An empty
    (or whitespace-only) non-JSON body is rejected when ``raise_error_on_empty_body`` is set.

    Args:
        response: The HTTP response object
        raise_error_on_empty_body: Whether to raise an error if the response body is empty

    Returns:
        The decoded JSON body if it parses, otherwise the body as text

    Raises:
        GatewayError: When the response text cannot be read, or when the body is empty and
            ``raise_error_on_empty_body`` is True
    """
    # Try to parse as JSON first
    try:
        return response.json()
    except Exception as json_error:
        # Guard ONLY the text access. The empty-body GatewayError below must stay outside this
        # handler; otherwise it would be caught here, logged a second time and re-raised with
        # the wrong ("unable to read") message.
        try:
            text_body = response.text
        except Exception:
            log.error(
                "Unable to read response body from authentication provider",
                url=str(response.url),
                status_code=response.status_code,
            )
            raise GatewayError(message="Unable to read response from authentication provider") from json_error

        if not text_body.strip() and raise_error_on_empty_body:  # Check for empty or whitespace-only response
            log.error(
                "Empty response body from authentication provider",
                url=str(response.url),
                status_code=response.status_code,
            )
            raise GatewayError(message="Authentication provider returned an empty response") from json_error

        # Here it means we got a text response but not JSON
        return text_body
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def extract_auth_error_message(response_body: str | dict[str, Any], base_message: str) -> str:
    """Build a user-facing error message from an OAuth 2.0 / OIDC error response body.

    RFC 6749 error bodies carry an ``error`` code and an optional human-readable
    ``error_description``. A truthy description is preferred over the code. When neither is
    usable, or the body is not a dict at all, ``base_message`` is returned unchanged.

    Args:
        response_body: The response body from the authentication provider
        base_message: Base error message to use if no specific error is found

    Returns:
        ``"<base_message>: <detail>"`` when a detail is available, otherwise ``base_message``
    """
    detail = None
    if isinstance(response_body, dict):
        # Fall back to the terse error code only when there is no usable description
        detail = response_body.get("error_description") or response_body.get("error")

    return f"{base_message}: {detail}" if detail else base_message
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def validate_auth_response(response: httpx.Response, provider_type: str = "authentication") -> None:
    """Check an HTTP response from an OAuth 2.0/OIDC provider and raise on failure.

    A 2xx response passes if its body can be read and is not empty. Any other status is logged
    with the provider type, URL, status code and body. It is then turned into a ``GatewayError``
    whose message depends on the status class and, where present, the RFC 6749 error details.

    Args:
        response: The HTTP response from the authentication provider
        provider_type: Provider label used only in log messages (e.g. "OAuth 2.0", "OIDC")

    Raises:
        GatewayError: When the status code is not 2xx, or when a 2xx body is empty/unreadable
    """
    status = response.status_code

    if 200 <= status <= 299:
        # A success status alone is not enough: the body must be readable and non-empty
        safe_get_response_body(response)
        return

    body = safe_get_response_body(response, raise_error_on_empty_body=False)

    # status code -> (log message, base message shown to the user)
    known_failures: dict[int, tuple[str, str]] = {
        400: (
            f"Bad request to {provider_type} provider",
            "Bad request to authentication provider. Please try again later or contact your administrator.",
        ),
        401: (
            f"Unauthorized request to {provider_type} provider",
            "Unauthorized request to authentication provider. Please try again later or contact your administrator.",
        ),
        403: (
            f"Forbidden request to {provider_type} provider",
            "Access forbidden by authentication provider. Please try again later or contact your administrator.",
        ),
        404: (
            f"Resource not found for {provider_type} provider",
            "Authentication provider endpoint not found. Please try again later or contact your administrator.",
        ),
        429: (
            f"Rate limited by {provider_type} provider",
            "Rate limited by authentication provider. Please try again later.",
        ),
    }

    if status in known_failures:
        log_message, base_msg = known_failures[status]
    elif 500 <= status <= 599:
        log_message = f"Server error from {provider_type} provider"
        base_msg = "Authentication provider is experiencing server issues. Please try again later or contact your administrator."
    else:
        log_message = f"Unexpected response from {provider_type} provider"
        base_msg = "Unexpected response from authentication provider."

    log.error(log_message, url=str(response.url), status_code=status, body=body)
    raise GatewayError(message=extract_auth_error_message(body, base_msg))
|
infrahub/cli/db.py
CHANGED
|
@@ -54,6 +54,7 @@ from infrahub.log import get_logger
|
|
|
54
54
|
|
|
55
55
|
from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
|
|
56
56
|
from .db_commands.check_inheritance import check_inheritance
|
|
57
|
+
from .db_commands.clean_duplicate_schema_fields import clean_duplicate_schema_fields
|
|
57
58
|
from .patch import patch_app
|
|
58
59
|
|
|
59
60
|
|
|
@@ -200,6 +201,29 @@ async def check_inheritance_cmd(
|
|
|
200
201
|
await dbdriver.close()
|
|
201
202
|
|
|
202
203
|
|
|
204
|
+
@app.command(name="check-duplicate-schema-fields")
async def check_duplicate_schema_fields_cmd(
    ctx: typer.Context,
    fix: bool = typer.Option(False, help="Fix the duplicate schema fields on the default branch."),
    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
) -> None:
    """Check for any duplicate schema attributes or relationships on the default branch"""
    # NOTE: the docstring above is the Typer --help text; keep it user-facing.
    # Quiet noisy loggers so only the check's own report is shown
    logging.getLogger("infrahub").setLevel(logging.WARNING)
    logging.getLogger("neo4j").setLevel(logging.ERROR)
    logging.getLogger("prefect").setLevel(logging.ERROR)

    config.load_and_exit(config_file_name=config_file)

    context: CliContext = ctx.obj
    dbdriver = await context.init_db(retry=1)

    # Close the driver on every exit path. Previously the non-zero exit below (and any
    # exception raised by the check) skipped the close and leaked the connection.
    try:
        success = await clean_duplicate_schema_fields(db=dbdriver, fix=fix)
    finally:
        await dbdriver.close()

    if not success:
        raise typer.Exit(code=1)
|
|
225
|
+
|
|
226
|
+
|
|
203
227
|
@app.command(name="update-core-schema")
|
|
204
228
|
async def update_core_schema_cmd(
|
|
205
229
|
ctx: typer.Context,
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from rich import print as rprint
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
|
|
9
|
+
from infrahub.cli.constants import FAILED_BADGE, SUCCESS_BADGE
|
|
10
|
+
from infrahub.core.query import Query, QueryType
|
|
11
|
+
from infrahub.database import InfrahubDatabase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SchemaFieldType(str, Enum):
    """Category of a duplicated schema field: an attribute or a relationship.

    Mixing in ``str`` makes each member compare equal to its raw string value.
    """

    ATTRIBUTE = "attribute"  # backed by a SchemaAttribute vertex
    RELATIONSHIP = "relationship"  # backed by a SchemaRelationship vertex
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SchemaFieldDetails:
    """One duplicated field found on a schema on the default branch."""

    # Kind of the owning schema, built as namespace + name (e.g. "InfraDevice")
    schema_kind: str
    # UUID of the owning SchemaNode/SchemaGeneric vertex
    schema_uuid: str
    # Whether the duplicated field is an attribute or a relationship
    field_type: SchemaFieldType
    # Name of the duplicated field
    field_name: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DuplicateSchemaFields(Query):
    """Shared Cypher prefix that groups duplicated schema fields on the default branch.

    The built query considers every SchemaAttribute/SchemaRelationship vertex whose most recent
    IS_PART_OF edge on the default branch is active and open-ended. It resolves each field's
    active ``name`` value and its owning SchemaNode/SchemaGeneric vertex. It then groups by
    (default_branch, field_name, is_attribute, schema_vertex) and keeps only groups with more
    than one field vertex.

    Variables left in scope for subclasses: ``default_branch``, ``field_name``,
    ``is_attribute``, ``schema_vertex`` and ``fields_reverse_chron``. The last one holds the
    grouped field vertices collected after ordering by ``active_from`` descending
    (newest first).
    """

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        self.add_to_query(
            """
        MATCH (root:Root)
        LIMIT 1
        WITH root.default_branch AS default_branch
        MATCH (field:SchemaAttribute|SchemaRelationship)
        CALL (default_branch, field) {
            MATCH (field)-[is_part_of:IS_PART_OF]->(:Root)
            WHERE is_part_of.branch = default_branch
            ORDER BY is_part_of.from DESC
            RETURN is_part_of
            LIMIT 1
        }
        WITH default_branch, field, CASE
            WHEN is_part_of.status = "active" AND is_part_of.to IS NULL THEN is_part_of.from
            ELSE NULL
        END AS active_from
        WHERE active_from IS NOT NULL
        WITH default_branch, field, active_from, "SchemaAttribute" IN labels(field) AS is_attribute
        CALL (field, default_branch) {
            MATCH (field)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
            WHERE r1.branch = default_branch AND r2.branch = default_branch
            AND r1.status = "active" AND r2.status = "active"
            AND r1.to IS NULL AND r2.to IS NULL
            ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
            LIMIT 1
            RETURN name_value.value AS field_name
        }
        CALL (field, default_branch) {
            MATCH (field)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer:SchemaNode|SchemaGeneric)
            WHERE rel.name IN ["schema__node__relationships", "schema__node__attributes"]
            AND r1.branch = default_branch AND r2.branch = default_branch
            AND r1.status = "active" AND r2.status = "active"
            AND r1.to IS NULL AND r2.to IS NULL
            ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
            LIMIT 1
            RETURN peer AS schema_vertex
        }
        WITH default_branch, field, field_name, is_attribute, active_from, schema_vertex
        ORDER BY active_from DESC
        WITH default_branch, field_name, is_attribute, schema_vertex, collect(field) AS fields_reverse_chron
        WHERE size(fields_reverse_chron) > 1
        """
        )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class GetDuplicateSchemaFields(DuplicateSchemaFields):
    """
    Get the kind, field type, and field name for any duplicated attributes or relationships on a given schema
    on the default branch
    """

    name = "get_duplicate_schema_fields"
    type = QueryType.READ
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
        """Append a read-only tail that resolves the owning schema's kind and returns one row per group."""
        await super().query_init(db=db, **kwargs)
        # One row per duplicated (schema, field) group, sorted by kind, then attributes before
        # relationships (is_attribute DESC), then field name.
        self.return_labels = ["schema_kind", "schema_uuid", "field_name", "is_attribute"]
        self.add_to_query(
            """
        CALL (schema_vertex, default_branch) {
            MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "namespace"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
            WHERE r1.branch = default_branch AND r2.branch = default_branch
            ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
            LIMIT 1
            RETURN name_value.value AS schema_namespace
        }
        CALL (schema_vertex, default_branch) {
            MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
            WHERE r1.branch = default_branch AND r2.branch = default_branch
            ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
            LIMIT 1
            RETURN name_value.value AS schema_name
        }
        RETURN schema_namespace + schema_name AS schema_kind, schema_vertex.uuid AS schema_uuid, field_name, is_attribute
        ORDER BY schema_kind ASC, is_attribute DESC, field_name ASC
        """
        )

    def get_schema_field_details(self) -> list[SchemaFieldDetails]:
        """Convert every result row into a ``SchemaFieldDetails``, preserving query order."""
        return [
            SchemaFieldDetails(
                schema_kind=row.get_as_type(label="schema_kind", return_type=str),
                schema_uuid=row.get_as_type(label="schema_uuid", return_type=str),
                field_name=row.get_as_type(label="field_name", return_type=str),
                field_type=(
                    SchemaFieldType.ATTRIBUTE
                    if row.get_as_type(label="is_attribute", return_type=bool)
                    else SchemaFieldType.RELATIONSHIP
                ),
            )
            for row in self.results
        ]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class FixDuplicateSchemaFields(DuplicateSchemaFields):
    """
    Fix the duplicate schema fields by hard deleting the earlier duplicate(s)
    """

    name = "fix_duplicate_schema_fields"
    type = QueryType.WRITE
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
        """Append a write tail that hard-deletes all but the newest field vertex of each group.

        ``tail(fields_reverse_chron)`` skips the head of the list (the newest vertex, which is
        kept). For every other vertex, on the default branch only, the tail:
          1. deletes its IS_PART_OF edges;
          2. deletes every edge of its linked Relationship vertices, then those vertices if they
             are left without any edge;
          3. deletes every edge of its Attribute vertices, then those vertices if they are left
             without any edge;
          4. deletes the field vertex itself if it has no edge left.

        NOTE(review): the MATCH clauses in steps 2-3 are not OPTIONAL. A field with no
        IS_RELATED (or HAS_ATTRIBUTE) edge on the default branch will therefore skip the
        remaining steps inside the subquery. Confirm this is acceptable for the data being
        repaired.
        """
        await super().query_init(db=db, **kwargs)
        self.add_to_query(
            """
        WITH default_branch, tail(fields_reverse_chron) AS fields_to_delete
        UNWIND fields_to_delete AS field_to_delete
        CALL (field_to_delete, default_branch) {
            MATCH (field_to_delete)-[r:IS_PART_OF {branch: default_branch}]-()
            DELETE r
            WITH field_to_delete
            MATCH (field_to_delete)-[:IS_RELATED {branch: default_branch}]-(rel:Relationship)
            WITH DISTINCT field_to_delete, rel
            MATCH (rel)-[r {branch: default_branch}]-()
            DELETE r
            WITH field_to_delete, rel
            OPTIONAL MATCH (rel)
            WHERE NOT exists((rel)--())
            DELETE rel
            WITH DISTINCT field_to_delete
            MATCH (field_to_delete)-[:HAS_ATTRIBUTE {branch: default_branch}]->(attr:Attribute)
            MATCH (attr)-[r {branch: default_branch}]-()
            DELETE r
            WITH field_to_delete, attr
            OPTIONAL MATCH (attr)
            WHERE NOT exists((attr)--())
            DELETE attr
            WITH DISTINCT field_to_delete
            OPTIONAL MATCH (field_to_delete)
            WHERE NOT exists((field_to_delete)--())
            DELETE field_to_delete
        }
        """
        )
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def display_duplicate_schema_fields(duplicate_schema_fields: list[SchemaFieldDetails]) -> None:
    """Render the duplicated schema fields as a rich table on the console.

    Args:
        duplicate_schema_fields: Rows to display, shown in the given order
    """
    table = Table(title="Duplicate Schema Fields on Default Branch")
    for header in ("Schema Kind", "Schema UUID", "Field Name", "Field Type"):
        table.add_column(header)

    for item in duplicate_schema_fields:
        table.add_row(item.schema_kind, item.schema_uuid, item.field_name, item.field_type.value)

    Console().print(table)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
async def clean_duplicate_schema_fields(db: InfrahubDatabase, fix: bool = False) -> bool:
    """Report, and optionally remove, duplicated schema fields on the default branch.

    Any duplicates found are printed as a table. With ``fix=True`` the earlier duplicate(s) of
    each group are hard-deleted by ``FixDuplicateSchemaFields``.

    Args:
        db: Database connection used to run the queries
        fix: Whether to delete the duplicates after reporting them

    Returns:
        True when no duplicates exist or they were deleted; False when duplicates exist and
        ``fix`` is False
    """
    lookup = await GetDuplicateSchemaFields.init(db=db)
    await lookup.execute(db=db)
    duplicates = lookup.get_schema_field_details()

    if duplicates:
        display_duplicate_schema_fields(duplicates)

        if not fix:
            rprint(f"{FAILED_BADGE} Use the --fix flag to fix the duplicate schema fields")
            return False

        repair = await FixDuplicateSchemaFields.init(db=db)
        await repair.execute(db=db)
        rprint(f"{SUCCESS_BADGE} Duplicate schema fields deleted from the default branch")
        return True

    rprint(f"{SUCCESS_BADGE} No duplicate schema fields found")
    return True
|
|
@@ -104,7 +104,7 @@ async def process_transform(
|
|
|
104
104
|
) # type: ignore[misc]
|
|
105
105
|
|
|
106
106
|
data = await client.query_gql_query(
|
|
107
|
-
name=transform.query.
|
|
107
|
+
name=transform.query.id,
|
|
108
108
|
branch_name=branch_name,
|
|
109
109
|
variables={"id": object_id},
|
|
110
110
|
update_group=True,
|
|
@@ -560,7 +560,7 @@ class RelationshipChangelogGetter:
|
|
|
560
560
|
|
|
561
561
|
for peer in relationship.peers:
|
|
562
562
|
if peer.peer_status == DiffAction.ADDED:
|
|
563
|
-
peer_schema = schema_branch.get(name=peer.peer_kind)
|
|
563
|
+
peer_schema = schema_branch.get(name=peer.peer_kind, duplicate=False)
|
|
564
564
|
secondaries.extend(
|
|
565
565
|
self._process_added_peers(
|
|
566
566
|
peer_id=peer.peer_id,
|
|
@@ -572,7 +572,7 @@ class RelationshipChangelogGetter:
|
|
|
572
572
|
)
|
|
573
573
|
|
|
574
574
|
elif peer.peer_status == DiffAction.REMOVED:
|
|
575
|
-
peer_schema = schema_branch.get(name=peer.peer_kind)
|
|
575
|
+
peer_schema = schema_branch.get(name=peer.peer_kind, duplicate=False)
|
|
576
576
|
secondaries.extend(
|
|
577
577
|
self._process_removed_peers(
|
|
578
578
|
peer_id=peer.peer_id,
|
|
@@ -44,7 +44,7 @@ CALL (source_artifact) {
|
|
|
44
44
|
MATCH (source_artifact)-[r:IS_PART_OF]->(:Root)
|
|
45
45
|
WHERE %(source_branch_filter)s
|
|
46
46
|
RETURN r AS root_rel
|
|
47
|
-
ORDER BY r.branch_level DESC, r.from DESC
|
|
47
|
+
ORDER BY r.branch_level DESC, r.from DESC, r.status ASC
|
|
48
48
|
LIMIT 1
|
|
49
49
|
}
|
|
50
50
|
WITH source_artifact, root_rel
|
|
@@ -61,7 +61,7 @@ CALL (source_artifact) {
|
|
|
61
61
|
target_node,
|
|
62
62
|
(rrel1.status = "active" AND rrel2.status = "active") AS target_is_active,
|
|
63
63
|
$source_branch_name IN [rrel1.branch, rrel2.branch] AS target_on_source_branch
|
|
64
|
-
ORDER BY rrel1.branch_level DESC,
|
|
64
|
+
ORDER BY rrel1.branch_level DESC, rrel2.branch_level DESC, rrel1.from DESC, rrel2.from DESC, rrel1.status ASC, rrel2.status ASC
|
|
65
65
|
LIMIT 1
|
|
66
66
|
}
|
|
67
67
|
// -----------------------
|
|
@@ -75,7 +75,7 @@ CALL (source_artifact) {
|
|
|
75
75
|
definition_node,
|
|
76
76
|
(rrel1.status = "active" AND rrel2.status = "active") AS definition_is_active,
|
|
77
77
|
$source_branch_name IN [rrel1.branch, rrel2.branch] AS definition_on_source_branch
|
|
78
|
-
ORDER BY rrel1.branch_level DESC,
|
|
78
|
+
ORDER BY rrel1.branch_level DESC, rrel2.branch_level DESC, rrel1.from DESC, rrel2.from DESC, rrel1.status ASC, rrel2.status ASC
|
|
79
79
|
LIMIT 1
|
|
80
80
|
}
|
|
81
81
|
// -----------------------
|
|
@@ -89,7 +89,8 @@ CALL (source_artifact) {
|
|
|
89
89
|
attr_val.value AS checksum,
|
|
90
90
|
(attr_rel.status = "active" AND value_rel.status = "active") AS checksum_is_active,
|
|
91
91
|
$source_branch_name IN [attr_rel.branch, value_rel.branch] AS checksum_on_source_branch
|
|
92
|
-
ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC
|
|
92
|
+
ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC,
|
|
93
|
+
value_rel.status ASC, attr_rel.status ASC
|
|
93
94
|
LIMIT 1
|
|
94
95
|
}
|
|
95
96
|
// -----------------------
|
|
@@ -103,7 +104,8 @@ CALL (source_artifact) {
|
|
|
103
104
|
attr_val.value AS storage_id,
|
|
104
105
|
(attr_rel.status = "active" AND value_rel.status = "active") AS storage_id_is_active,
|
|
105
106
|
$source_branch_name IN [attr_rel.branch, value_rel.branch] AS storage_id_on_source_branch
|
|
106
|
-
ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC
|
|
107
|
+
ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC,
|
|
108
|
+
value_rel.status ASC, attr_rel.status ASC
|
|
107
109
|
LIMIT 1
|
|
108
110
|
}
|
|
109
111
|
WITH target_node, target_is_active, target_on_source_branch,
|
|
@@ -146,8 +148,9 @@ CALL (target_node, definition_node){
|
|
|
146
148
|
)
|
|
147
149
|
RETURN
|
|
148
150
|
target_artifact,
|
|
149
|
-
(trel1.status = "active" AND trel2.status = "active" AND drel1.status = "active" AND
|
|
150
|
-
ORDER BY trel1.from DESC, trel2.from DESC, drel1.from DESC, drel2.from DESC
|
|
151
|
+
(trel1.status = "active" AND trel2.status = "active" AND drel1.status = "active" AND drel2.status = "active") AS artifact_is_active
|
|
152
|
+
ORDER BY trel1.from DESC, trel2.from DESC, drel1.from DESC, drel2.from DESC,
|
|
153
|
+
trel1.status ASC, trel2.status ASC, drel1.status ASC, drel2.status ASC
|
|
151
154
|
LIMIT 1
|
|
152
155
|
}
|
|
153
156
|
WITH CASE
|
|
@@ -163,7 +166,7 @@ CALL (target_node, definition_node){
|
|
|
163
166
|
AND attr_rel.branch = $target_branch_name
|
|
164
167
|
AND value_rel.branch = $target_branch_name
|
|
165
168
|
RETURN attr_val.value AS checksum, (attr_rel.status = "active" AND value_rel.status = "active") AS checksum_is_active
|
|
166
|
-
ORDER BY value_rel.from DESC, attr_rel.from DESC
|
|
169
|
+
ORDER BY value_rel.from DESC, attr_rel.from DESC, value_rel.status ASC, attr_rel.status ASC
|
|
167
170
|
LIMIT 1
|
|
168
171
|
}
|
|
169
172
|
// -----------------------
|
|
@@ -175,7 +178,7 @@ CALL (target_node, definition_node){
|
|
|
175
178
|
AND attr_rel.branch = $target_branch_name
|
|
176
179
|
AND value_rel.branch = $target_branch_name
|
|
177
180
|
RETURN attr_val.value AS storage_id, (attr_rel.status = "active" AND value_rel.status = "active") AS storage_id_is_active
|
|
178
|
-
ORDER BY value_rel.from DESC, attr_rel.from DESC
|
|
181
|
+
ORDER BY value_rel.from DESC, attr_rel.from DESC, value_rel.status ASC, attr_rel.status ASC
|
|
179
182
|
LIMIT 1
|
|
180
183
|
}
|
|
181
184
|
RETURN target_artifact,
|