infrahub-server 1.4.9__py3-none-any.whl → 1.4.11__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published.
Files changed (46)
  1. infrahub/api/oauth2.py +13 -19
  2. infrahub/api/oidc.py +15 -21
  3. infrahub/artifacts/models.py +2 -1
  4. infrahub/auth.py +137 -3
  5. infrahub/cli/db.py +24 -0
  6. infrahub/cli/db_commands/clean_duplicate_schema_fields.py +212 -0
  7. infrahub/computed_attribute/tasks.py +1 -1
  8. infrahub/core/changelog/models.py +2 -2
  9. infrahub/core/diff/query/artifact.py +12 -9
  10. infrahub/core/ipam/utilization.py +1 -1
  11. infrahub/core/manager.py +6 -3
  12. infrahub/core/node/__init__.py +3 -1
  13. infrahub/core/node/constraints/attribute_uniqueness.py +3 -1
  14. infrahub/core/node/create.py +12 -3
  15. infrahub/core/registry.py +2 -2
  16. infrahub/core/relationship/constraints/count.py +1 -1
  17. infrahub/core/relationship/model.py +1 -1
  18. infrahub/core/schema/definitions/internal.py +4 -0
  19. infrahub/core/schema/manager.py +19 -1
  20. infrahub/core/schema/node_schema.py +4 -2
  21. infrahub/core/schema/schema_branch.py +8 -0
  22. infrahub/core/validators/determiner.py +12 -1
  23. infrahub/core/validators/relationship/peer.py +1 -1
  24. infrahub/core/validators/tasks.py +1 -1
  25. infrahub/generators/tasks.py +3 -7
  26. infrahub/git/integrator.py +1 -1
  27. infrahub/git/models.py +2 -1
  28. infrahub/git/repository.py +22 -5
  29. infrahub/git/tasks.py +14 -8
  30. infrahub/git/utils.py +123 -1
  31. infrahub/graphql/analyzer.py +1 -1
  32. infrahub/graphql/mutations/main.py +3 -3
  33. infrahub/graphql/mutations/schema.py +5 -5
  34. infrahub/message_bus/types.py +2 -1
  35. infrahub/middleware.py +26 -1
  36. infrahub/proposed_change/tasks.py +11 -12
  37. infrahub/server.py +12 -3
  38. infrahub/workers/dependencies.py +8 -1
  39. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/METADATA +17 -17
  40. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/RECORD +46 -45
  41. infrahub_testcontainers/container.py +1 -1
  42. infrahub_testcontainers/docker-compose-cluster.test.yml +1 -1
  43. infrahub_testcontainers/docker-compose.test.yml +1 -1
  44. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/LICENSE.txt +0 -0
  45. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/WHEEL +0 -0
  46. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/entry_points.txt +0 -0
infrahub/api/oauth2.py CHANGED
@@ -11,14 +11,16 @@ from opentelemetry import trace
 
 from infrahub import config, models
 from infrahub.api.dependencies import get_db
-from infrahub.auth import get_groups_from_provider, signin_sso_account
-from infrahub.exceptions import GatewayError, ProcessingError
+from infrahub.auth import (
+    get_groups_from_provider,
+    signin_sso_account,
+    validate_auth_response,
+)
+from infrahub.exceptions import ProcessingError
 from infrahub.log import get_logger
 from infrahub.message_bus.types import KVTTL
 
 if TYPE_CHECKING:
-    import httpx
-
     from infrahub.database import InfrahubDatabase
     from infrahub.services import InfrahubServices
 
@@ -95,7 +97,7 @@ async def token(
     }
 
     token_response = await service.http.post(provider.token_url, data=token_data)
-    _validate_response(response=token_response)
+    validate_auth_response(response=token_response, provider_type="OAuth 2.0")
 
     with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
         span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
@@ -107,12 +109,17 @@ async def token(
     else:
         userinfo_response = await service.http.post(provider.userinfo_url, headers=headers)
 
-    _validate_response(response=userinfo_response)
+    validate_auth_response(response=userinfo_response, provider_type="OAuth 2.0")
     user_info = userinfo_response.json()
     sso_groups = user_info.get("groups", []) or await get_groups_from_provider(
         provider=provider, service=service, payload=payload, user_info=user_info
     )
 
+    log.info(
+        "SSO user authenticated",
+        body={"user_name": user_info.get("name"), "groups": sso_groups},
+    )
+
     if not sso_groups and config.SETTINGS.security.sso_user_default_group:
         sso_groups = [config.SETTINGS.security.sso_user_default_group]
 
@@ -134,16 +141,3 @@ async def token(
     return models.UserTokenWithUrl(
         access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=stored_final_url
     )
-
-
-def _validate_response(response: httpx.Response) -> None:
-    if 200 <= response.status_code <= 299:
-        return
-
-    log.error(
-        "Invalid response from the OAuth provider",
-        url=response.url,
-        status_code=response.status_code,
-        body=response.json(),
-    )
-    raise GatewayError(message="Invalid response from Authentication provider")
infrahub/api/oidc.py CHANGED
@@ -13,14 +13,16 @@ from pydantic import BaseModel, HttpUrl
 
 from infrahub import config, models
 from infrahub.api.dependencies import get_db
-from infrahub.auth import get_groups_from_provider, signin_sso_account
-from infrahub.exceptions import GatewayError, ProcessingError
+from infrahub.auth import (
+    get_groups_from_provider,
+    signin_sso_account,
+    validate_auth_response,
+)
+from infrahub.exceptions import ProcessingError
 from infrahub.log import get_logger
 from infrahub.message_bus.types import KVTTL
 
 if TYPE_CHECKING:
-    import httpx
-
     from infrahub.database import InfrahubDatabase
     from infrahub.services import InfrahubServices
 
@@ -69,7 +71,7 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
     service: InfrahubServices = request.app.state.service
 
     response = await service.http.get(url=provider.discovery_url)
-    _validate_response(response=response)
+    validate_auth_response(response=response, provider_type="OIDC")
     oidc_config = OIDCDiscoveryConfig(**response.json())
 
     with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
@@ -129,12 +131,12 @@ async def token(
     }
 
     discovery_response = await service.http.get(url=provider.discovery_url)
-    _validate_response(response=discovery_response)
+    validate_auth_response(response=discovery_response, provider_type="OIDC")
 
     oidc_config = OIDCDiscoveryConfig(**discovery_response.json())
 
     token_response = await service.http.post(str(oidc_config.token_endpoint), data=token_data)
-    _validate_response(response=token_response)
+    validate_auth_response(response=token_response, provider_type="OIDC")
 
     with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
         span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
@@ -147,7 +149,7 @@ async def token(
     else:
         userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)
 
-    _validate_response(response=userinfo_response)
+    validate_auth_response(response=userinfo_response, provider_type="OIDC")
     user_info: dict[str, Any] = userinfo_response.json()
     sso_groups = (
         user_info.get("groups")
@@ -157,6 +159,11 @@ async def token(
         or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info)
     )
 
+    log.info(
+        "SSO user authenticated",
+        body={"user_name": user_info.get("name"), "groups": sso_groups},
+    )
+
     if not sso_groups and config.SETTINGS.security.sso_user_default_group:
         sso_groups = [config.SETTINGS.security.sso_user_default_group]
 
@@ -180,19 +187,6 @@ async def token(
     )
 
 
-def _validate_response(response: httpx.Response) -> None:
-    if 200 <= response.status_code <= 299:
-        return
-
-    log.error(
-        "Invalid response from the OIDC provider",
-        url=response.url,
-        status_code=response.status_code,
-        body=response.json(),
-    )
-    raise GatewayError(message="Invalid response from Authentication provider")
-
-
 async def _get_id_token_groups(
     oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
 ) -> list[str]:
infrahub/artifacts/models.py CHANGED
@@ -25,7 +25,8 @@ class CheckArtifactCreate(BaseModel):
     target_kind: str = Field(..., description="The kind of the target object for this artifact")
     target_name: str = Field(..., description="Name of the artifact target")
     artifact_id: str | None = Field(default=None, description="The id of the artifact if it previously existed")
-    query: str = Field(..., description="The name of the query to use when collecting data")
+    query: str = Field(..., description="The name of the query to use when collecting data")  # Deprecated
+    query_id: str = Field(..., description="The id of the query to use when collecting data")
     timeout: int = Field(..., description="Timeout for requests used to generate this artifact")
     variables: dict = Field(..., description="Input variables when generating the artifact")
     validator_id: str = Field(..., description="The ID of the validator")
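
Note: `CheckArtifactCreate` now carries both the deprecated query name and the new `query_id`. A hedged sketch of the transitional payload shape implied by the model change (both fields are required, so producers presumably send both during the migration window; the values below are hypothetical):

    payload_fields = {
        "query": "device_config",   # deprecated: query name, kept for compatibility
        "query_id": "17f1aa54-...", # hypothetical id; ids are now used for lookups,
    }                               # matching the tasks.py change further below
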
infrahub/auth.py CHANGED
@@ -3,27 +3,37 @@ from __future__ import annotations
 
 import uuid
 from datetime import datetime, timedelta, timezone
 from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import bcrypt
 import jwt
 from pydantic import BaseModel
 
 from infrahub import config, models
-from infrahub.config import SecurityOAuth2Google, SecurityOAuth2Settings, SecurityOIDCGoogle, SecurityOIDCSettings
+from infrahub.config import (
+    SecurityOAuth2Google,
+    SecurityOAuth2Settings,
+    SecurityOIDCGoogle,
+    SecurityOIDCSettings,
+)
 from infrahub.core.account import validate_token
 from infrahub.core.constants import AccountStatus, InfrahubKind
 from infrahub.core.manager import NodeManager
 from infrahub.core.node import Node
 from infrahub.core.protocols import CoreAccount, CoreAccountGroup
 from infrahub.core.registry import registry
-from infrahub.exceptions import AuthorizationError, NodeNotFoundError
+from infrahub.exceptions import AuthorizationError, GatewayError, NodeNotFoundError
+from infrahub.log import get_logger
 
 if TYPE_CHECKING:
+    import httpx
+
     from infrahub.core.protocols import CoreGenericAccount
     from infrahub.database import InfrahubDatabase
     from infrahub.services import InfrahubServices
 
+log = get_logger()
+
 
 class AuthType(str, Enum):
     NONE = "none"
@@ -256,3 +266,127 @@ async def get_groups_from_provider(
         return [membership["groupKey"]["id"] for membership in group_memberships["memberships"]]
 
     return []
+
+
+def safe_get_response_body(response: httpx.Response, raise_error_on_empty_body: bool = True) -> str | dict[str, Any]:
+    """Safely extract the body from an HTTP response. Falls back to text when the body is not
+    valid JSON, and raises a GatewayError when the body is empty or cannot be read at all.
+
+    Args:
+        response: The HTTP response object
+        raise_error_on_empty_body: Whether to raise an error if the response body is empty
+
+    Returns:
+        The response body as a JSON dict if possible, otherwise as text
+
+    Raises:
+        GatewayError: When the response body is empty or cannot be read
+    """
+    # Try to parse as JSON first
+    try:
+        return response.json()
+    except Exception as json_error:
+        # Fall back to the raw text body
+        try:
+            text_body = response.text
+        except Exception:
+            log.error(
+                "Unable to read response body from authentication provider",
+                url=str(response.url),
+                status_code=response.status_code,
+            )
+            raise GatewayError(message="Unable to read response from authentication provider") from json_error
+
+        if not text_body.strip() and raise_error_on_empty_body:  # Empty or whitespace-only response
+            log.error(
+                "Empty response body from authentication provider",
+                url=str(response.url),
+                status_code=response.status_code,
+            )
+            raise GatewayError(message="Authentication provider returned an empty response") from json_error
+
+        # Here it means we got a text response but not JSON
+        return text_body
+
+
+def extract_auth_error_message(response_body: str | dict[str, Any], base_message: str) -> str:
+    """Extract an error message from an OAuth 2.0/OIDC provider response, following RFC 6749.
+
+    Args:
+        response_body: The response body from the authentication provider
+        base_message: Base error message to use if no specific error is found
+
+    Returns:
+        Formatted error message with provider details if available
+    """
+    if not isinstance(response_body, dict):
+        return base_message
+
+    # RFC 6749 standard error response format
+    error_description = response_body.get("error_description")
+    error_code = response_body.get("error")
+
+    if error_description:
+        return f"{base_message}: {error_description}"
+    if error_code:
+        return f"{base_message}: {error_code}"
+
+    return base_message
+
+
+def validate_auth_response(response: httpx.Response, provider_type: str = "authentication") -> None:
+    """Validate an HTTP response from an OAuth 2.0/OIDC provider and raise appropriate errors.
+
+    Args:
+        response: The HTTP response from the authentication provider
+        provider_type: Type of provider for logging (e.g., "OAuth 2.0", "OIDC")
+
+    Raises:
+        GatewayError: When the response indicates an error or invalid state
+    """
+    # If the status code is successful, simply return
+    if 200 <= response.status_code <= 299:
+        # Verify that we can read the response body safely and that it is not empty
+        safe_get_response_body(response)
+        return
+
+    # Prepare variables with default values for logging
+    response_body = safe_get_response_body(response, raise_error_on_empty_body=False)
+    log_message: str = f"Unexpected response from {provider_type} provider"
+    base_msg: str = "Unexpected response from authentication provider."
+
+    # Handle specific HTTP status codes with appropriate error messages
+    match response.status_code:
+        case 400:
+            log_message = f"Bad request to {provider_type} provider"
+            base_msg = "Bad request to authentication provider. Please try again later or contact your administrator."
+
+        case 401:
+            log_message = f"Unauthorized request to {provider_type} provider"
+            base_msg = (
+                "Unauthorized request to authentication provider. Please try again later or contact your administrator."
+            )
+
+        case 403:
+            log_message = f"Forbidden request to {provider_type} provider"
+            base_msg = (
+                "Access forbidden by authentication provider. Please try again later or contact your administrator."
+            )
+
+        case 404:
+            log_message = f"Resource not found for {provider_type} provider"
+            base_msg = (
+                "Authentication provider endpoint not found. Please try again later or contact your administrator."
+            )
+
+        case 429:
+            log_message = f"Rate limited by {provider_type} provider"
+            base_msg = "Rate limited by authentication provider. Please try again later."
+
+        case status_code if 500 <= status_code <= 599:
+            log_message = f"Server error from {provider_type} provider"
+            base_msg = "Authentication provider is experiencing server issues. Please try again later or contact your administrator."
+
+    # Log the failure and raise a gateway error
+    log.error(log_message, url=str(response.url), status_code=response.status_code, body=response_body)
+    error_msg = extract_auth_error_message(response_body, base_msg)
+    raise GatewayError(message=error_msg)
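
Note: a minimal sketch of how the new shared validator behaves, assuming an environment where infrahub and httpx are importable; the IdP URL and payloads below are hypothetical:

    import httpx
    from infrahub.auth import validate_auth_response
    from infrahub.exceptions import GatewayError

    request = httpx.Request("POST", "https://idp.example.com/token")  # hypothetical endpoint

    # A 400 with an RFC 6749 body surfaces error_description in the GatewayError message.
    bad = httpx.Response(400, json={"error": "invalid_grant", "error_description": "Code expired"}, request=request)
    try:
        validate_auth_response(response=bad, provider_type="OIDC")
    except GatewayError as exc:
        print(getattr(exc, "message", exc))  # ends with ": Code expired" via extract_auth_error_message

    # Unlike the old per-module _validate_response, a 2xx with an empty body is now rejected,
    # since every caller parses the body as JSON immediately afterwards.
    empty = httpx.Response(204, content=b"", request=request)
    try:
        validate_auth_response(response=empty, provider_type="OAuth 2.0")
    except GatewayError:
        print("empty body rejected")
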
infrahub/cli/db.py CHANGED
@@ -54,6 +54,7 @@ from infrahub.log import get_logger
 
 from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
 from .db_commands.check_inheritance import check_inheritance
+from .db_commands.clean_duplicate_schema_fields import clean_duplicate_schema_fields
 from .patch import patch_app
 
 
@@ -200,6 +201,29 @@ async def check_inheritance_cmd(
     await dbdriver.close()
 
 
+@app.command(name="check-duplicate-schema-fields")
+async def check_duplicate_schema_fields_cmd(
+    ctx: typer.Context,
+    fix: bool = typer.Option(False, help="Fix the duplicate schema fields on the default branch."),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Check for any duplicate schema attributes or relationships on the default branch"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    context: CliContext = ctx.obj
+    dbdriver = await context.init_db(retry=1)
+
+    success = await clean_duplicate_schema_fields(db=dbdriver, fix=fix)
+    if not success:
+        raise typer.Exit(code=1)
+
+    await dbdriver.close()
+
+
 @app.command(name="update-core-schema")
 async def update_core_schema_cmd(
     ctx: typer.Context,
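
Note: presumably invoked through the `db` command group as `infrahub db check-duplicate-schema-fields [--fix] [CONFIG_FILE]`, mirroring the existing `check-inheritance` command (the group prefix is an assumption; only the command registration is shown above). A non-zero exit code signals that duplicates were found but not fixed.
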
infrahub/cli/db_commands/clean_duplicate_schema_fields.py ADDED
@@ -0,0 +1,212 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+
+from rich import print as rprint
+from rich.console import Console
+from rich.table import Table
+
+from infrahub.cli.constants import FAILED_BADGE, SUCCESS_BADGE
+from infrahub.core.query import Query, QueryType
+from infrahub.database import InfrahubDatabase
+
+
+class SchemaFieldType(str, Enum):
+    ATTRIBUTE = "attribute"
+    RELATIONSHIP = "relationship"
+
+
+@dataclass
+class SchemaFieldDetails:
+    schema_kind: str
+    schema_uuid: str
+    field_type: SchemaFieldType
+    field_name: str
+
+
+class DuplicateSchemaFields(Query):
+    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
+        query = """
+MATCH (root:Root)
+LIMIT 1
+WITH root.default_branch AS default_branch
+MATCH (field:SchemaAttribute|SchemaRelationship)
+CALL (default_branch, field) {
+    MATCH (field)-[is_part_of:IS_PART_OF]->(:Root)
+    WHERE is_part_of.branch = default_branch
+    ORDER BY is_part_of.from DESC
+    RETURN is_part_of
+    LIMIT 1
+}
+WITH default_branch, field, CASE
+    WHEN is_part_of.status = "active" AND is_part_of.to IS NULL THEN is_part_of.from
+    ELSE NULL
+END AS active_from
+WHERE active_from IS NOT NULL
+WITH default_branch, field, active_from, "SchemaAttribute" IN labels(field) AS is_attribute
+CALL (field, default_branch) {
+    MATCH (field)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
+    WHERE r1.branch = default_branch AND r2.branch = default_branch
+    AND r1.status = "active" AND r2.status = "active"
+    AND r1.to IS NULL AND r2.to IS NULL
+    ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
+    LIMIT 1
+    RETURN name_value.value AS field_name
+}
+CALL (field, default_branch) {
+    MATCH (field)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer:SchemaNode|SchemaGeneric)
+    WHERE rel.name IN ["schema__node__relationships", "schema__node__attributes"]
+    AND r1.branch = default_branch AND r2.branch = default_branch
+    AND r1.status = "active" AND r2.status = "active"
+    AND r1.to IS NULL AND r2.to IS NULL
+    ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
+    LIMIT 1
+    RETURN peer AS schema_vertex
+}
+WITH default_branch, field, field_name, is_attribute, active_from, schema_vertex
+ORDER BY active_from DESC
+WITH default_branch, field_name, is_attribute, schema_vertex, collect(field) AS fields_reverse_chron
+WHERE size(fields_reverse_chron) > 1
+"""
+        self.add_to_query(query)
+
+
+class GetDuplicateSchemaFields(DuplicateSchemaFields):
+    """
+    Get the kind, field type, and field name for any duplicated attributes or relationships on a given schema
+    on the default branch
+    """
+
+    name = "get_duplicate_schema_fields"
+    type = QueryType.READ
+    insert_return = False
+
+    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
+        await super().query_init(db=db, **kwargs)
+        query = """
+CALL (schema_vertex, default_branch) {
+    MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "namespace"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
+    WHERE r1.branch = default_branch AND r2.branch = default_branch
+    ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
+    LIMIT 1
+    RETURN name_value.value AS schema_namespace
+}
+CALL (schema_vertex, default_branch) {
+    MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
+    WHERE r1.branch = default_branch AND r2.branch = default_branch
+    ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
+    LIMIT 1
+    RETURN name_value.value AS schema_name
+}
+RETURN schema_namespace + schema_name AS schema_kind, schema_vertex.uuid AS schema_uuid, field_name, is_attribute
+ORDER BY schema_kind ASC, is_attribute DESC, field_name ASC
+"""
+        self.return_labels = ["schema_kind", "schema_uuid", "field_name", "is_attribute"]
+        self.add_to_query(query)
+
+    def get_schema_field_details(self) -> list[SchemaFieldDetails]:
+        schema_field_details: list[SchemaFieldDetails] = []
+        for result in self.results:
+            schema_kind = result.get_as_type(label="schema_kind", return_type=str)
+            schema_uuid = result.get_as_type(label="schema_uuid", return_type=str)
+            field_name = result.get_as_type(label="field_name", return_type=str)
+            is_attribute = result.get_as_type(label="is_attribute", return_type=bool)
+            schema_field_details.append(
+                SchemaFieldDetails(
+                    schema_kind=schema_kind,
+                    schema_uuid=schema_uuid,
+                    field_name=field_name,
+                    field_type=SchemaFieldType.ATTRIBUTE if is_attribute else SchemaFieldType.RELATIONSHIP,
+                )
+            )
+        return schema_field_details
+
+
+class FixDuplicateSchemaFields(DuplicateSchemaFields):
+    """
+    Fix the duplicate schema fields by hard deleting the earlier duplicate(s)
+    """
+
+    name = "fix_duplicate_schema_fields"
+    type = QueryType.WRITE
+    insert_return = False
+
+    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
+        await super().query_init(db=db, **kwargs)
+        query = """
+WITH default_branch, tail(fields_reverse_chron) AS fields_to_delete
+UNWIND fields_to_delete AS field_to_delete
+CALL (field_to_delete, default_branch) {
+    MATCH (field_to_delete)-[r:IS_PART_OF {branch: default_branch}]-()
+    DELETE r
+    WITH field_to_delete
+    MATCH (field_to_delete)-[:IS_RELATED {branch: default_branch}]-(rel:Relationship)
+    WITH DISTINCT field_to_delete, rel
+    MATCH (rel)-[r {branch: default_branch}]-()
+    DELETE r
+    WITH field_to_delete, rel
+    OPTIONAL MATCH (rel)
+    WHERE NOT exists((rel)--())
+    DELETE rel
+    WITH DISTINCT field_to_delete
+    MATCH (field_to_delete)-[:HAS_ATTRIBUTE {branch: default_branch}]->(attr:Attribute)
+    MATCH (attr)-[r {branch: default_branch}]-()
+    DELETE r
+    WITH field_to_delete, attr
+    OPTIONAL MATCH (attr)
+    WHERE NOT exists((attr)--())
+    DELETE attr
+    WITH DISTINCT field_to_delete
+    OPTIONAL MATCH (field_to_delete)
+    WHERE NOT exists((field_to_delete)--())
+    DELETE field_to_delete
+}
+"""
+        self.add_to_query(query)
+
+
+def display_duplicate_schema_fields(duplicate_schema_fields: list[SchemaFieldDetails]) -> None:
+    console = Console()
+
+    table = Table(title="Duplicate Schema Fields on Default Branch")
+
+    table.add_column("Schema Kind")
+    table.add_column("Schema UUID")
+    table.add_column("Field Name")
+    table.add_column("Field Type")
+
+    for duplicate_schema_field in duplicate_schema_fields:
+        table.add_row(
+            duplicate_schema_field.schema_kind,
+            duplicate_schema_field.schema_uuid,
+            duplicate_schema_field.field_name,
+            duplicate_schema_field.field_type.value,
+        )
+
+    console.print(table)
+
+
+async def clean_duplicate_schema_fields(db: InfrahubDatabase, fix: bool = False) -> bool:
+    """
+    Identify any attributes or relationships that are duplicated in a schema on the default branch.
+    If fix is True, runs cypher queries to hard delete the earlier duplicate(s).
+    """
+
+    duplicate_schema_fields_query = await GetDuplicateSchemaFields.init(db=db)
+    await duplicate_schema_fields_query.execute(db=db)
+    duplicate_schema_fields = duplicate_schema_fields_query.get_schema_field_details()
+
+    if not duplicate_schema_fields:
+        rprint(f"{SUCCESS_BADGE} No duplicate schema fields found")
+        return True
+
+    display_duplicate_schema_fields(duplicate_schema_fields)
+
+    if not fix:
+        rprint(f"{FAILED_BADGE} Use the --fix flag to fix the duplicate schema fields")
+        return False
+
+    fix_duplicate_schema_fields_query = await FixDuplicateSchemaFields.init(db=db)
+    await fix_duplicate_schema_fields_query.execute(db=db)
+    rprint(f"{SUCCESS_BADGE} Duplicate schema fields deleted from the default branch")
+    return True
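
Note: a minimal programmatic sketch of the report-then-fix flow that the CLI command above wires up; acquiring the `InfrahubDatabase` driver is environment-specific and omitted here:

    from infrahub.cli.db_commands.clean_duplicate_schema_fields import clean_duplicate_schema_fields
    from infrahub.database import InfrahubDatabase

    async def check_and_fix(db: InfrahubDatabase) -> None:
        # Report-only pass: prints a rich table and returns False when duplicates exist.
        if not await clean_duplicate_schema_fields(db=db, fix=False):
            # Destructive pass: hard-deletes the older duplicate(s) on the default branch.
            await clean_duplicate_schema_fields(db=db, fix=True)
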
infrahub/computed_attribute/tasks.py CHANGED
@@ -104,7 +104,7 @@ async def process_transform(
     )  # type: ignore[misc]
 
     data = await client.query_gql_query(
-        name=transform.query.peer.name.value,
+        name=transform.query.id,
         branch_name=branch_name,
        variables={"id": object_id},
         update_group=True,
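
Note: the transform's GraphQL query is now resolved by id rather than by the peer's name, so the peer node no longer needs to be loaded for this call. This assumes `query_gql_query` accepts an id in its `name` parameter, which the change implies but the diff does not show.
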
infrahub/core/changelog/models.py CHANGED
@@ -560,7 +560,7 @@ class RelationshipChangelogGetter:
 
         for peer in relationship.peers:
             if peer.peer_status == DiffAction.ADDED:
-                peer_schema = schema_branch.get(name=peer.peer_kind)
+                peer_schema = schema_branch.get(name=peer.peer_kind, duplicate=False)
                 secondaries.extend(
                     self._process_added_peers(
                         peer_id=peer.peer_id,
@@ -572,7 +572,7 @@ class RelationshipChangelogGetter:
                 )
 
             elif peer.peer_status == DiffAction.REMOVED:
-                peer_schema = schema_branch.get(name=peer.peer_kind)
+                peer_schema = schema_branch.get(name=peer.peer_kind, duplicate=False)
                 secondaries.extend(
                     self._process_removed_peers(
                         peer_id=peer.peer_id,
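
Note: `duplicate=False` presumably returns the cached schema object rather than a defensive copy, which is safe for this read-only lookup and avoids per-peer copy overhead; the parameter's exact semantics are not shown in this diff.
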
infrahub/core/diff/query/artifact.py CHANGED
@@ -44,7 +44,7 @@ CALL (source_artifact) {
     MATCH (source_artifact)-[r:IS_PART_OF]->(:Root)
     WHERE %(source_branch_filter)s
     RETURN r AS root_rel
-    ORDER BY r.branch_level DESC, r.from DESC
+    ORDER BY r.branch_level DESC, r.from DESC, r.status ASC
     LIMIT 1
 }
 WITH source_artifact, root_rel
@@ -61,7 +61,7 @@ CALL (source_artifact) {
         target_node,
         (rrel1.status = "active" AND rrel2.status = "active") AS target_is_active,
         $source_branch_name IN [rrel1.branch, rrel2.branch] AS target_on_source_branch
-    ORDER BY rrel1.branch_level DESC, rrel1.branch_level DESC, rrel1.from DESC, rrel2.from DESC
+    ORDER BY rrel1.branch_level DESC, rrel2.branch_level DESC, rrel1.from DESC, rrel2.from DESC, rrel1.status ASC, rrel2.status ASC
     LIMIT 1
 }
 // -----------------------
@@ -75,7 +75,7 @@ CALL (source_artifact) {
         definition_node,
         (rrel1.status = "active" AND rrel2.status = "active") AS definition_is_active,
         $source_branch_name IN [rrel1.branch, rrel2.branch] AS definition_on_source_branch
-    ORDER BY rrel1.branch_level DESC, rrel1.branch_level DESC, rrel1.from DESC, rrel2.from DESC
+    ORDER BY rrel1.branch_level DESC, rrel2.branch_level DESC, rrel1.from DESC, rrel2.from DESC, rrel1.status ASC, rrel2.status ASC
     LIMIT 1
 }
 // -----------------------
@@ -89,7 +89,8 @@ CALL (source_artifact) {
         attr_val.value AS checksum,
         (attr_rel.status = "active" AND value_rel.status = "active") AS checksum_is_active,
         $source_branch_name IN [attr_rel.branch, value_rel.branch] AS checksum_on_source_branch
-    ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC
+    ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC,
+        value_rel.status ASC, attr_rel.status ASC
     LIMIT 1
 }
 // -----------------------
@@ -103,7 +104,8 @@ CALL (source_artifact) {
         attr_val.value AS storage_id,
         (attr_rel.status = "active" AND value_rel.status = "active") AS storage_id_is_active,
         $source_branch_name IN [attr_rel.branch, value_rel.branch] AS storage_id_on_source_branch
-    ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC
+    ORDER BY value_rel.branch_level DESC, attr_rel.branch_level DESC, value_rel.from DESC, attr_rel.from DESC,
+        value_rel.status ASC, attr_rel.status ASC
     LIMIT 1
 }
 WITH target_node, target_is_active, target_on_source_branch,
@@ -146,8 +148,9 @@ CALL (target_node, definition_node){
     )
     RETURN
         target_artifact,
-        (trel1.status = "active" AND trel2.status = "active" AND drel1.status = "active" AND drel1.status = "active") AS artifact_is_active
-    ORDER BY trel1.from DESC, trel2.from DESC, drel1.from DESC, drel2.from DESC
+        (trel1.status = "active" AND trel2.status = "active" AND drel1.status = "active" AND drel2.status = "active") AS artifact_is_active
+    ORDER BY trel1.from DESC, trel2.from DESC, drel1.from DESC, drel2.from DESC,
+        trel1.status ASC, trel2.status ASC, drel1.status ASC, drel2.status ASC
     LIMIT 1
 }
 WITH CASE
@@ -163,7 +166,7 @@ CALL (target_node, definition_node){
         AND attr_rel.branch = $target_branch_name
         AND value_rel.branch = $target_branch_name
     RETURN attr_val.value AS checksum, (attr_rel.status = "active" AND value_rel.status = "active") AS checksum_is_active
-    ORDER BY value_rel.from DESC, attr_rel.from DESC
+    ORDER BY value_rel.from DESC, attr_rel.from DESC, value_rel.status ASC, attr_rel.status ASC
     LIMIT 1
 }
 // -----------------------
@@ -175,7 +178,7 @@ CALL (target_node, definition_node){
         AND attr_rel.branch = $target_branch_name
         AND value_rel.branch = $target_branch_name
     RETURN attr_val.value AS storage_id, (attr_rel.status = "active" AND value_rel.status = "active") AS storage_id_is_active
-    ORDER BY value_rel.from DESC, attr_rel.from DESC
+    ORDER BY value_rel.from DESC, attr_rel.from DESC, value_rel.status ASC, attr_rel.status ASC
     LIMIT 1
 }
 RETURN target_artifact,
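
Note: the recurring `status ASC` additions act as a deterministic tiebreaker, presumably so that an "active" edge wins over a "deleted" one when their `from` timestamps are equal, since "active" sorts before "deleted" lexically. A minimal sketch of the ordering in Python:

    edges = [
        {"status": "deleted", "from": "2024-06-01T00:00:00Z"},
        {"status": "active", "from": "2024-06-01T00:00:00Z"},
    ]
    # Emulate ORDER BY from DESC, status ASC with two stable sorts (secondary key first).
    edges.sort(key=lambda e: e["status"])               # status ASC
    edges.sort(key=lambda e: e["from"], reverse=True)   # from DESC
    assert edges[0]["status"] == "active"
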