infrahub-server 1.6.0b0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infrahub/api/oauth2.py +33 -6
- infrahub/api/oidc.py +36 -6
- infrahub/auth.py +11 -0
- infrahub/auth_pkce.py +41 -0
- infrahub/config.py +9 -3
- infrahub/core/branch/models.py +3 -2
- infrahub/core/branch/tasks.py +6 -1
- infrahub/core/changelog/models.py +2 -2
- infrahub/core/constants/__init__.py +1 -0
- infrahub/core/graph/__init__.py +1 -1
- infrahub/core/integrity/object_conflict/conflict_recorder.py +1 -1
- infrahub/core/manager.py +36 -31
- infrahub/core/migrations/graph/__init__.py +4 -0
- infrahub/core/migrations/graph/m041_deleted_dup_edges.py +30 -12
- infrahub/core/migrations/graph/m047_backfill_or_null_display_label.py +606 -0
- infrahub/core/migrations/graph/m048_undelete_rel_props.py +161 -0
- infrahub/core/models.py +5 -6
- infrahub/core/node/__init__.py +16 -13
- infrahub/core/node/create.py +36 -8
- infrahub/core/node/proposed_change.py +5 -3
- infrahub/core/node/standard.py +1 -1
- infrahub/core/protocols.py +1 -7
- infrahub/core/query/attribute.py +1 -1
- infrahub/core/query/node.py +9 -5
- infrahub/core/relationship/model.py +21 -4
- infrahub/core/schema/generic_schema.py +1 -1
- infrahub/core/schema/manager.py +8 -3
- infrahub/core/schema/schema_branch.py +35 -16
- infrahub/core/validators/attribute/choices.py +2 -2
- infrahub/core/validators/determiner.py +3 -6
- infrahub/database/__init__.py +1 -1
- infrahub/git/base.py +2 -3
- infrahub/git/models.py +13 -0
- infrahub/git/tasks.py +23 -19
- infrahub/git/utils.py +16 -9
- infrahub/graphql/app.py +6 -6
- infrahub/graphql/loaders/peers.py +6 -0
- infrahub/graphql/mutations/action.py +15 -7
- infrahub/graphql/mutations/hfid.py +1 -1
- infrahub/graphql/mutations/profile.py +8 -1
- infrahub/graphql/mutations/repository.py +3 -3
- infrahub/graphql/mutations/schema.py +4 -4
- infrahub/graphql/mutations/webhook.py +2 -2
- infrahub/graphql/queries/resource_manager.py +2 -3
- infrahub/graphql/queries/search.py +2 -3
- infrahub/graphql/resolvers/ipam.py +20 -0
- infrahub/graphql/resolvers/many_relationship.py +12 -11
- infrahub/graphql/resolvers/resolver.py +6 -2
- infrahub/graphql/resolvers/single_relationship.py +1 -11
- infrahub/log.py +1 -1
- infrahub/message_bus/messages/__init__.py +0 -12
- infrahub/profiles/node_applier.py +9 -0
- infrahub/proposed_change/branch_diff.py +1 -1
- infrahub/proposed_change/tasks.py +1 -1
- infrahub/repositories/create_repository.py +3 -3
- infrahub/task_manager/models.py +1 -1
- infrahub/task_manager/task.py +5 -5
- infrahub/trigger/setup.py +6 -9
- infrahub/utils.py +18 -0
- infrahub/validators/tasks.py +1 -1
- infrahub/workers/infrahub_async.py +7 -6
- infrahub_sdk/client.py +113 -1
- infrahub_sdk/ctl/AGENTS.md +67 -0
- infrahub_sdk/ctl/branch.py +175 -1
- infrahub_sdk/ctl/check.py +3 -3
- infrahub_sdk/ctl/cli_commands.py +9 -9
- infrahub_sdk/ctl/generator.py +2 -2
- infrahub_sdk/ctl/graphql.py +1 -2
- infrahub_sdk/ctl/importer.py +1 -2
- infrahub_sdk/ctl/repository.py +6 -49
- infrahub_sdk/ctl/task.py +2 -4
- infrahub_sdk/ctl/utils.py +2 -2
- infrahub_sdk/ctl/validate.py +1 -2
- infrahub_sdk/diff.py +80 -3
- infrahub_sdk/graphql/constants.py +14 -1
- infrahub_sdk/graphql/renderers.py +5 -1
- infrahub_sdk/node/attribute.py +0 -1
- infrahub_sdk/node/constants.py +3 -1
- infrahub_sdk/node/node.py +303 -3
- infrahub_sdk/node/related_node.py +1 -2
- infrahub_sdk/node/relationship.py +1 -2
- infrahub_sdk/protocols_base.py +0 -1
- infrahub_sdk/pytest_plugin/AGENTS.md +67 -0
- infrahub_sdk/schema/__init__.py +0 -3
- infrahub_sdk/timestamp.py +7 -7
- {infrahub_server-1.6.0b0.dist-info → infrahub_server-1.6.2.dist-info}/METADATA +2 -3
- {infrahub_server-1.6.0b0.dist-info → infrahub_server-1.6.2.dist-info}/RECORD +91 -86
- {infrahub_server-1.6.0b0.dist-info → infrahub_server-1.6.2.dist-info}/WHEEL +1 -1
- infrahub_testcontainers/container.py +2 -2
- {infrahub_server-1.6.0b0.dist-info → infrahub_server-1.6.2.dist-info}/entry_points.txt +0 -0
- {infrahub_server-1.6.0b0.dist-info → infrahub_server-1.6.2.dist-info}/licenses/LICENSE.txt +0 -0
infrahub/api/oauth2.py
CHANGED
@@ -12,10 +12,12 @@ from opentelemetry import trace
 from infrahub import config, models
 from infrahub.api.dependencies import get_db
 from infrahub.auth import (
+    SSOStateCache,
     get_groups_from_provider,
     signin_sso_account,
     validate_auth_response,
 )
+from infrahub.auth_pkce import compute_code_challenge, generate_code_verifier
 from infrahub.exceptions import ProcessingError
 from infrahub.log import get_logger
 from infrahub.message_bus.types import KVTTL
@@ -42,6 +44,7 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
     with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
         span.set_attribute("provider_name", provider_name)
         span.set_attribute("scopes", provider.scopes)
+        span.set_attribute("pkce_enabled", provider.pkce_enabled)

     client = AsyncOAuth2Client(
         client_id=provider.client_id,
@@ -52,14 +55,32 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
     redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
     final_url = final_url or config.SETTINGS.main.public_url or str(request.base_url)

+    # Generate PKCE parameters if enabled
+    code_verifier = None
+    pkce_params: dict[str, str] = {}
+    if provider.pkce_enabled:
+        code_verifier = generate_code_verifier()
+        code_challenge = compute_code_challenge(code_verifier)
+        pkce_params = {
+            "code_challenge": code_challenge,
+            "code_challenge_method": "S256",
+        }
+
     authorization_uri, state = client.create_authorization_url(
-        url=provider.authorization_url,
+        url=provider.authorization_url,
+        redirect_uri=redirect_uri,
+        scope=provider.scopes,
+        final_url=final_url,
+        **pkce_params,
     )

     service: InfrahubServices = request.app.state.service

+    cache_data = SSOStateCache(final_url=final_url, code_verifier=code_verifier)
     await service.cache.set(
-        key=f"security:oauth2:provider:{provider_name}:state:{state}",
+        key=f"security:oauth2:provider:{provider_name}:state:{state}",
+        value=cache_data.model_dump_json(),
+        expires=KVTTL.TWO_HOURS,
     )

     if config.SETTINGS.dev.frontend_redirect_sso:
@@ -82,13 +103,15 @@ async def token(
     service: InfrahubServices = request.app.state.service

     cache_key = f"security:oauth2:provider:{provider_name}:state:{state}"
-
+    cached_data = await service.cache.get(key=cache_key)
     await service.cache.delete(key=cache_key)

-    if not
+    if not cached_data:
         raise ProcessingError(message="Invalid 'state' parameter")

-
+    sso_state = SSOStateCache.model_validate_json(cached_data)
+
+    token_data: dict[str, str | None] = {
         "code": code,
         "client_id": provider.client_id,
         "client_secret": provider.client_secret,
@@ -96,6 +119,10 @@ async def token(
         "grant_type": "authorization_code",
     }

+    # Add code_verifier if PKCE was used
+    if sso_state.code_verifier:
+        token_data["code_verifier"] = sso_state.code_verifier
+
     token_response = await service.http.post(provider.token_url, data=token_data)
     validate_auth_response(response=token_response, provider_type="OAuth 2.0")

@@ -139,5 +166,5 @@ async def token(
     )

     return models.UserTokenWithUrl(
-        access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=
+        access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=sso_state.final_url
     )
infrahub/api/oidc.py
CHANGED
@@ -14,10 +14,12 @@ from pydantic import BaseModel, HttpUrl
 from infrahub import config, models
 from infrahub.api.dependencies import get_db
 from infrahub.auth import (
+    SSOStateCache,
     get_groups_from_provider,
     signin_sso_account,
     validate_auth_response,
 )
+from infrahub.auth_pkce import compute_code_challenge, generate_code_verifier
 from infrahub.exceptions import ProcessingError
 from infrahub.log import get_logger
 from infrahub.message_bus.types import KVTTL
@@ -58,6 +60,10 @@ class OIDCDiscoveryConfig(BaseModel):
     tls_client_certificate_bound_access_tokens: bool | None = None
     mtls_endpoint_aliases: dict[str, HttpUrl] | None = None

+    @property
+    def supports_pkce(self) -> bool:
+        return "S256" in (self.code_challenge_methods_supported or [])
+

 def _get_redirect_url(request: Request, provider_name: str) -> str:
     """Return public redirect URL."""
@@ -74,10 +80,14 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
     validate_auth_response(response=response, provider_type="OIDC")
     oidc_config = OIDCDiscoveryConfig(**response.json())

+    pkce_supported = oidc_config.supports_pkce
+
     with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
         span.set_attribute("provider_name", provider_name)
         span.set_attribute("scopes", provider.scopes)
         span.set_attribute("discovery_url", provider.discovery_url)
+        span.set_attribute("pkce_enabled", provider.pkce_enabled)
+        span.set_attribute("pkce_supported", pkce_supported)

     client = AsyncOAuth2Client(
         client_id=provider.client_id,
@@ -88,12 +98,26 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
     redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
     final_url = final_url or config.SETTINGS.main.public_url or str(request.base_url)

+    # Generate PKCE parameters if enabled and supported by provider
+    code_verifier = None
+    pkce_params: dict[str, str] = {}
+    if provider.pkce_enabled and pkce_supported:
+        code_verifier = generate_code_verifier()
+        code_challenge = compute_code_challenge(code_verifier)
+        pkce_params = {
+            "code_challenge": code_challenge,
+            "code_challenge_method": "S256",
+        }
+
     authorization_uri, state = client.create_authorization_url(
-        url=str(oidc_config.authorization_endpoint), redirect_uri=redirect_uri, scope=provider.scopes
+        url=str(oidc_config.authorization_endpoint), redirect_uri=redirect_uri, scope=provider.scopes, **pkce_params
     )

+    cache_data = SSOStateCache(final_url=final_url, code_verifier=code_verifier)
     await service.cache.set(
-        key=f"security:oidc:provider:{provider_name}:state:{state}",
+        key=f"security:oidc:provider:{provider_name}:state:{state}",
+        value=cache_data.model_dump_json(),
+        expires=KVTTL.TWO_HOURS,
     )

     if config.SETTINGS.dev.frontend_redirect_sso:
@@ -116,13 +140,15 @@ async def token(
     service: InfrahubServices = request.app.state.service

     cache_key = f"security:oidc:provider:{provider_name}:state:{state}"
-
+    cached_data = await service.cache.get(key=cache_key)
     await service.cache.delete(key=cache_key)

-    if not
+    if not cached_data:
         raise ProcessingError(message="Invalid 'state' parameter")

-
+    sso_state = SSOStateCache.model_validate_json(cached_data)
+
+    token_data: dict[str, str | None] = {
         "code": code,
         "client_id": provider.client_id,
         "client_secret": provider.client_secret,
@@ -130,6 +156,10 @@ async def token(
         "grant_type": "authorization_code",
     }

+    # Add code_verifier if PKCE was used
+    if sso_state.code_verifier:
+        token_data["code_verifier"] = sso_state.code_verifier
+
     discovery_response = await service.http.get(url=provider.discovery_url)
     validate_auth_response(response=discovery_response, provider_type="OIDC")

@@ -183,7 +213,7 @@ async def token(
     )

     return models.UserTokenWithUrl(
-        access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=
+        access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=sso_state.final_url
     )

infrahub/auth.py
CHANGED
@@ -51,6 +51,17 @@ class AccountSession(BaseModel):
         return self.auth_type == AuthType.JWT


+class SSOStateCache(BaseModel):
+    """Cache data stored during OAuth2/OIDC authorization flow.
+
+    This model is used to store state information between the authorization
+    request and the token exchange, including PKCE code_verifier when enabled.
+    """
+
+    final_url: str
+    code_verifier: str | None = None
+
+
 async def validate_active_account(db: InfrahubDatabase, account_id: str) -> None:
     account = await NodeManager.get_one(db=db, kind=CoreGenericAccount, id=account_id, raise_on_error=True)
     if account.status.value != AccountStatus.ACTIVE.value:
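The new model is what ties the two endpoints together: authorize() serializes it into the state cache and token() restores it after the redirect. A minimal standalone sketch of that round trip, using a local copy of the model and hypothetical values for final_url and code_verifier:

    from pydantic import BaseModel


    class SSOStateCache(BaseModel):
        final_url: str
        code_verifier: str | None = None


    # authorize(): serialize the payload before storing it under the provider/state cache key
    cached = SSOStateCache(
        final_url="https://infrahub.example.com/",      # hypothetical value
        code_verifier="example-code-verifier",          # hypothetical; None when PKCE is disabled
    ).model_dump_json()

    # token(): restore the payload once the browser comes back with the matching state
    sso_state = SSOStateCache.model_validate_json(cached)
    assert sso_state.code_verifier == "example-code-verifier"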
infrahub/auth_pkce.py
ADDED
@@ -0,0 +1,41 @@
+"""PKCE (RFC 7636) utilities for OAuth2/OIDC authentication.
+
+This module provides functions to generate code verifiers and compute
+code challenges for the Proof Key for Code Exchange (PKCE) extension.
+"""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import secrets
+
+
+def generate_code_verifier() -> str:
+    """Generate a cryptographically random code verifier.
+
+    The code verifier is a high-entropy cryptographic random string using
+    the unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~",
+    with a minimum length of 43 characters and a maximum length of 128
+    characters.
+
+    Returns:
+        A 43-character URL-safe string (256 bits of entropy).
+    """
+    return secrets.token_urlsafe(32)
+
+
+def compute_code_challenge(code_verifier: str) -> str:
+    """Compute S256 code challenge from verifier.
+
+    Implements the S256 code challenge method as defined in RFC 7636:
+        code_challenge = BASE64URL(SHA256(code_verifier))
+
+    Args:
+        code_verifier: The code verifier string.
+
+    Returns:
+        Base64URL-encoded SHA256 hash without padding.
+    """
+    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+    return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
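Taken together, generate_code_verifier() and compute_code_challenge() implement the two halves of the S256 exchange: the verifier stays in the server-side state cache while only its hash travels in the authorization URL, and the identity provider later repeats the same hash over the verifier sent with the token request. A minimal standalone sketch of that relationship, inlining the same stdlib calls rather than importing the new module:

    import base64
    import hashlib
    import secrets

    # authorize(): create the verifier and derive its S256 challenge
    code_verifier = secrets.token_urlsafe(32)  # 43-character URL-safe string
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

    # token(): the provider recomputes the hash over the verifier it receives
    # and accepts the request only if it matches the challenge seen earlier.
    recomputed = base64.urlsafe_b64encode(
        hashlib.sha256(code_verifier.encode("ascii")).digest()
    ).decode("ascii").rstrip("=")
    assert recomputed == code_challenge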
infrahub/config.py
CHANGED
@@ -327,7 +327,7 @@ class DevelopmentSettings(BaseSettings):
         description="Allow enterprise configuration in development mode, this will not enable the features just allow the configuration.",
     )
     git_credential_helper: str = Field(
-        default="
+        default="infrahub-git-credential",
         description="Location of git credential helper",
     )

@@ -585,11 +585,14 @@ class SecurityOIDCBaseSettings(BaseSettings):
     icon: str = Field(default="mdi:account-key")
     display_label: str = Field(default="Single Sign on")
     userinfo_method: UserInfoMethod = Field(default=UserInfoMethod.GET)
+    pkce_enabled: bool = Field(
+        default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow"
+    )


 class SecurityOIDCSettings(SecurityOIDCBaseSettings):
     client_id: str = Field(..., description="Client ID of the application created in the auth provider")
-    client_secret: str = Field(
+    client_secret: str | None = Field(default=None, description="Client secret as defined in auth provider")
     discovery_url: str = Field(..., description="The OIDC discovery URL xyz/.well-known/openid-configuration")
     scopes: list[str] = Field(default_factory=_default_scopes)

@@ -637,13 +640,16 @@ class SecurityOAuth2BaseSettings(BaseSettings):

     icon: str = Field(default="mdi:account-key")
     userinfo_method: UserInfoMethod = Field(default=UserInfoMethod.GET)
+    pkce_enabled: bool = Field(
+        default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow"
+    )


 class SecurityOAuth2Settings(SecurityOAuth2BaseSettings):
     """Common base for Oauth2 providers"""

     client_id: str = Field(..., description="Client ID of the application created in the auth provider")
-    client_secret: str = Field(
+    client_secret: str | None = Field(default=None, description="Client secret as defined in auth provider")
     authorization_url: str = Field(...)
     token_url: str = Field(...)
     userinfo_url: str = Field(...)
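Making client_secret optional and defaulting pkce_enabled to True means a provider can now be configured as a public client. A simplified sketch of the resulting settings shape (plain pydantic instead of the project's BaseSettings; the class name, field subset and client ID value are illustrative only):

    from pydantic import BaseModel, Field


    class OAuth2ProviderSettings(BaseModel):
        # Reduced stand-in for SecurityOAuth2Settings, limited to the fields touched above.
        client_id: str
        client_secret: str | None = Field(default=None)
        pkce_enabled: bool = Field(default=True)


    # A public client: no secret configured, PKCE still on by default.
    provider = OAuth2ProviderSettings(client_id="infrahub-sso")  # hypothetical client ID
    assert provider.client_secret is None
    assert provider.pkce_enabled is True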
infrahub/core/branch/models.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 import re
 from typing import TYPE_CHECKING, Any, Optional, Self, Union, cast

-from neo4j.graph import Node as Neo4jNode
 from pydantic import Field, field_validator

 from infrahub.core.branch.enums import BranchStatus
@@ -24,6 +23,8 @@ from infrahub.core.timestamp import Timestamp
 from infrahub.exceptions import BranchNotFoundError, InitializationError, ValidationError

 if TYPE_CHECKING:
+    from neo4j.graph import Node as Neo4jNode
+
     from infrahub.database import InfrahubDatabase


@@ -168,7 +169,7 @@ class Branch(StandardNode):
         )
         await query.execute(db=db)

-        return [cls.from_db(node=cast(Neo4jNode, result.get("n"))) for result in query.get_results()]
+        return [cls.from_db(node=cast("Neo4jNode", result.get("n"))) for result in query.get_results()]

     @classmethod
     async def get_list_count(
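The same pattern recurs in changelog/models.py and conflict_recorder.py below: the runtime import is dropped, the name is imported only under TYPE_CHECKING, and cast() receives the type as a string so nothing has to be resolved at import time. A minimal sketch of the idiom (the function name and dictionary shape are illustrative):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Any, cast

    if TYPE_CHECKING:
        # Seen only by type checkers; no runtime dependency on the neo4j driver.
        from neo4j.graph import Node as Neo4jNode


    def first_node(results: list[dict[str, Any]]) -> Neo4jNode:
        # cast() is a no-op at runtime and simply returns its second argument,
        # so passing the type as a string keeps the module importable without neo4j.
        return cast("Neo4jNode", results[0]["n"])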
infrahub/core/branch/tasks.py
CHANGED
@@ -157,7 +157,12 @@ async def rebase_branch(branch: str, context: InfrahubContext, send_events: bool
     responses = await schema_validate_migrations(
         message=SchemaValidateMigrationData(branch=obj, schema_branch=candidate_schema, constraints=constraints)
     )
-    error_messages = [
+    error_messages = [
+        f"{violation.message} for constraint {response.constraint_name} {response.schema_path.field_name} {response.schema_path.property_name} and node {violation.node_id} {violation.node_kind}"  # noqa: E501
+        for response in responses
+        for violation in response.violations
+    ]
+
     if error_messages:
         raise ValidationError(",\n".join(error_messages))

infrahub/core/changelog/models.py
CHANGED
@@ -290,7 +290,7 @@ class NodeChangelog(BaseModel):
                 name=relationship.schema.name
             )
         relationship_container = cast(
-            RelationshipCardinalityManyChangelog, self.relationships[relationship.schema.name]
+            "RelationshipCardinalityManyChangelog", self.relationships[relationship.schema.name]
         )

         relationship_container.add_new_peer(relationship=relationship)
@@ -311,7 +311,7 @@ class NodeChangelog(BaseModel):
                 name=relationship.schema.name
             )
         relationship_container = cast(
-            RelationshipCardinalityManyChangelog, self.relationships[relationship.schema.name]
+            "RelationshipCardinalityManyChangelog", self.relationships[relationship.schema.name]
         )
         relationship_container.remove_peer(
             peer_id=relationship.get_peer_id(), peer_kind=relationship.get_peer_kind()

infrahub/core/constants/__init__.py
CHANGED
@@ -391,3 +391,4 @@ DEFAULT_REL_IDENTIFIER_LENGTH = 128
 OBJECT_TEMPLATE_RELATIONSHIP_NAME = "object_template"
 OBJECT_TEMPLATE_NAME_ATTR = "template_name"
 PROFILE_NODE_RELATIONSHIP_IDENTIFIER = "node__profile"
+PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER = "template__profile"
infrahub/core/graph/__init__.py
CHANGED
@@ -1 +1 @@
-GRAPH_VERSION =
+GRAPH_VERSION = 48

infrahub/core/integrity/object_conflict/conflict_recorder.py
CHANGED
@@ -24,7 +24,7 @@ class ObjectConflictValidatorRecorder:
             )
         except NodeNotFoundError:
             return []
-        proposed_change = cast(CoreProposedChange, proposed_change)
+        proposed_change = cast("CoreProposedChange", proposed_change)
         validator = await self.get_or_create_validator(proposed_change)
         await self.initialize_validator(validator)

infrahub/core/manager.py
CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Literal, TypeVar, overload

 from infrahub_sdk.utils import deep_merge_dict, is_valid_uuid

-from infrahub.core.constants import RelationshipCardinality, RelationshipDirection
+from infrahub.core.constants import InfrahubKind, RelationshipCardinality, RelationshipDirection
 from infrahub.core.node import Node
 from infrahub.core.node.delete_validator import NodeDeleteValidator
 from infrahub.core.query.node import (
@@ -188,17 +188,22 @@ class NodeManager:
         await query.execute(db=db)
         node_ids = query.get_node_ids()

-
-
+        if (
+            fields
+            and "identifier" in fields
+            and node_schema.kind
+            in [
+                InfrahubKind.BASEPERMISSION,
+                InfrahubKind.GLOBALPERMISSION,
+                InfrahubKind.OBJECTPERMISSION,
+            ]
+        ):
+            # This is a workaround to ensure we are querying the right fields for permissions
+            # The identifier for permissions needs the same fields as the display label
             schema_branch = db.schema.get_schema_branch(name=branch.name)
             display_label_fields = schema_branch.generate_fields_for_display_label(name=node_schema.kind)
             if display_label_fields:
-
-
-        if fields and "hfid" in fields and node_schema.human_friendly_id:
-            hfid_fields = node_schema.generate_fields_for_hfid()
-            if hfid_fields:
-                fields = deep_merge_dict(dicta=fields, dictb=hfid_fields)
+                deep_merge_dict(dicta=fields, dictb=display_label_fields)

         response = await cls.get_many(
             ids=node_ids,
@@ -300,6 +305,8 @@ class NodeManager:
         branch: Branch | str | None = None,
         branch_agnostic: bool = False,
         fetch_peers: bool = False,
+        include_source: bool = False,
+        include_owner: bool = False,
     ) -> list[Relationship]:
         branch = await registry.get_branch(branch=branch, db=db)
         at = Timestamp(at)
@@ -324,24 +331,31 @@ class NodeManager:
         if not peers_info:
             return []

-
-
-
-        schema_branch = db.schema.get_schema_branch(name=branch.name)
-        display_label_fields = schema_branch.generate_fields_for_display_label(name=peer_schema.kind)
-        if display_label_fields:
-            fields = deep_merge_dict(dicta=fields, dictb=display_label_fields)
-
-        if fields and "hfid" in fields:
+        if fields and "identifier" in fields:
+            # This is a workaround to ensure we are querying the right fields for permissions
+            # The identifier for permissions needs the same fields as the display label
             peer_schema = schema.get_peer_schema(db=db, branch=branch)
-
-
-
+            if peer_schema.kind in [
+                InfrahubKind.BASEPERMISSION,
+                InfrahubKind.GLOBALPERMISSION,
+                InfrahubKind.OBJECTPERMISSION,
+            ]:
+                schema_branch = db.schema.get_schema_branch(name=branch.name)
+                display_label_fields = schema_branch.generate_fields_for_display_label(name=peer_schema.kind)
+                if display_label_fields:
+                    deep_merge_dict(dicta=fields, dictb=display_label_fields)

         if fetch_peers:
             peer_ids = [peer.peer_id for peer in peers_info]
             peer_nodes = await cls.get_many(
-                db=db,
+                db=db,
+                ids=peer_ids,
+                fields=fields,
+                at=at,
+                branch=branch,
+                branch_agnostic=branch_agnostic,
+                include_source=include_source,
+                include_owner=include_owner,
             )

             results = []
@@ -420,15 +434,6 @@ class NodeManager:
         if not peers_ids:
             return {}

-        hierarchy_schema = node_schema.get_hierarchy_schema(db=db, branch=branch)
-
-        # if display_label has been requested we need to ensure we are querying the right fields
-        if fields and "display_label" in fields:
-            schema_branch = db.schema.get_schema_branch(name=branch.name)
-            display_label_fields = schema_branch.generate_fields_for_display_label(name=hierarchy_schema.kind)
-            if display_label_fields:
-                fields = deep_merge_dict(dicta=fields, dictb=display_label_fields)
-
         return await cls.get_many(
             db=db, ids=peers_ids, fields=fields, at=at, branch=branch, include_owner=True, include_source=True
         )

infrahub/core/migrations/graph/__init__.py
CHANGED
@@ -48,6 +48,8 @@ from .m043_create_hfid_display_label_in_db import Migration043
 from .m044_backfill_hfid_display_label_in_db import Migration044
 from .m045_backfill_hfid_display_label_in_db_profile_template import Migration045
 from .m046_fill_agnostic_hfid_display_labels import Migration046
+from .m047_backfill_or_null_display_label import Migration047
+from .m048_undelete_rel_props import Migration048

 if TYPE_CHECKING:
     from ..shared import MigrationTypes
@@ -100,6 +102,8 @@ MIGRATIONS: list[type[MigrationTypes]] = [
     Migration044,
     Migration045,
     Migration046,
+    Migration047,
+    Migration048,
 ]


infrahub/core/migrations/graph/m041_deleted_dup_edges.py
CHANGED
@@ -68,33 +68,49 @@ DELETE added_e
         self.add_to_query(query)


-class
-    name = "
+class DeleteDuplicatedRelationshipEdges(Query):
+    name = "delete_duplicated_relationship_edges_query"
     type = QueryType.WRITE
     insert_return = False

+    def __init__(self, migrated_kind_nodes_only: bool = True, **kwargs: Any):
+        self.migrated_kind_nodes_only = migrated_kind_nodes_only
+        super().__init__(**kwargs)
+
     async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:  # noqa: ARG002
-
+        if not self.migrated_kind_nodes_only:
+            relationship_filter_query = """
+            MATCH (rel:Relationship)
+            """
+
+        else:
+            relationship_filter_query = """
             // ------------
             // get UUIDs for migrated kind/inheritance nodes
             // ------------
             MATCH (n:Node)
             WITH n.uuid AS node_uuid, count(*) AS num_nodes_with_uuid
             WHERE num_nodes_with_uuid > 1
-
-
-
-
-
-
-
+            // ------------
+            // find any Relationships for these nodes
+            // ------------
+            MATCH (n:Node {uuid: node_uuid})-[:IS_RELATED]-(rel:Relationship)
+            WITH DISTINCT rel
+
+            """
+        self.add_to_query(relationship_filter_query)
+
+        query = """
+        CALL (rel) {
+            MATCH (rel)-[e]-(peer)
             WITH
+                elementId(rel) AS rel_element_id,
                 type(e) AS e_type,
                 e.branch AS e_branch,
                 e.from AS e_from,
                 e.to AS e_to,
                 e.status AS e_status,
-
+                elementId(peer) AS peer_element_id,
                 CASE
                     WHEN startNode(e) = rel THEN "out" ELSE "in"
                 END AS direction,
@@ -142,7 +158,9 @@ class Migration041(ArbitraryMigration):
         rprint("done")

         rprint("Deleting duplicate edges for migrated kind/inheritance nodes", end="...")
-        delete_duplicate_edges_query = await
+        delete_duplicate_edges_query = await DeleteDuplicatedRelationshipEdges.init(
+            db=db, migrated_kind_nodes_only=True
+        )
         await delete_duplicate_edges_query.execute(db=db)
         rprint("done")