infrahub-server 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. infrahub/api/__init__.py +13 -5
  2. infrahub/api/artifact.py +2 -1
  3. infrahub/api/auth.py +7 -1
  4. infrahub/api/diff/diff.py +13 -7
  5. infrahub/api/file.py +3 -3
  6. infrahub/api/internal.py +19 -6
  7. infrahub/api/oauth2.py +22 -7
  8. infrahub/api/oidc.py +23 -7
  9. infrahub/api/schema.py +39 -20
  10. infrahub/api/storage.py +8 -8
  11. infrahub/api/transformation.py +3 -2
  12. infrahub/auth.py +1 -24
  13. infrahub/cli/__init__.py +1 -1
  14. infrahub/cli/context.py +5 -8
  15. infrahub/cli/db.py +6 -6
  16. infrahub/cli/git_agent.py +1 -1
  17. infrahub/config.py +1 -2
  18. infrahub/core/attribute.py +22 -0
  19. infrahub/core/constants/__init__.py +5 -0
  20. infrahub/core/diff/calculator.py +14 -0
  21. infrahub/core/diff/combiner.py +5 -2
  22. infrahub/core/diff/conflicts_enricher.py +2 -2
  23. infrahub/core/diff/coordinator.py +11 -1
  24. infrahub/core/diff/enricher/cardinality_one.py +3 -3
  25. infrahub/core/diff/enricher/hierarchy.py +2 -1
  26. infrahub/core/diff/merger/merger.py +10 -0
  27. infrahub/core/diff/merger/serializer.py +5 -29
  28. infrahub/core/diff/model/path.py +3 -1
  29. infrahub/core/diff/query_parser.py +26 -11
  30. infrahub/core/diff/repository/repository.py +4 -4
  31. infrahub/core/ipam/utilization.py +6 -1
  32. infrahub/core/merge.py +5 -0
  33. infrahub/core/migrations/query/attribute_add.py +5 -5
  34. infrahub/core/query/diff.py +32 -19
  35. infrahub/core/query/ipam.py +30 -22
  36. infrahub/core/query/node.py +4 -0
  37. infrahub/core/schema/__init__.py +5 -0
  38. infrahub/core/validators/attribute/kind.py +1 -1
  39. infrahub/core/validators/models/violation.py +1 -14
  40. infrahub/core/validators/tasks.py +4 -1
  41. infrahub/dependencies/builder/constraint/schema/aggregated.py +2 -0
  42. infrahub/dependencies/builder/constraint/schema/attribute_kind.py +8 -0
  43. infrahub/graphql/api/endpoints.py +12 -3
  44. infrahub/graphql/mutations/account.py +4 -4
  45. infrahub/graphql/mutations/main.py +5 -16
  46. infrahub/graphql/mutations/resource_manager.py +3 -3
  47. infrahub/graphql/queries/resource_manager.py +21 -10
  48. infrahub/task_manager/task.py +5 -1
  49. infrahub_sdk/analyzer.py +1 -1
  50. infrahub_sdk/checks.py +4 -4
  51. infrahub_sdk/client.py +26 -16
  52. infrahub_sdk/generator.py +3 -3
  53. infrahub_sdk/node.py +2 -2
  54. infrahub_sdk/pytest_plugin/items/base.py +0 -5
  55. infrahub_sdk/repository.py +33 -0
  56. infrahub_sdk/testing/repository.py +14 -8
  57. infrahub_sdk/transforms.py +3 -3
  58. infrahub_sdk/utils.py +8 -3
  59. {infrahub_server-1.1.2.dist-info → infrahub_server-1.1.4.dist-info}/METADATA +2 -1
  60. {infrahub_server-1.1.2.dist-info → infrahub_server-1.1.4.dist-info}/RECORD +63 -62
  61. infrahub_sdk/task_report.py +0 -208
  62. {infrahub_server-1.1.2.dist-info → infrahub_server-1.1.4.dist-info}/LICENSE.txt +0 -0
  63. {infrahub_server-1.1.2.dist-info → infrahub_server-1.1.4.dist-info}/WHEEL +0 -0
  64. {infrahub_server-1.1.2.dist-info → infrahub_server-1.1.4.dist-info}/entry_points.txt +0 -0
infrahub/api/__init__.py CHANGED
@@ -1,11 +1,13 @@
1
- from typing import NoReturn
1
+ from __future__ import annotations
2
2
 
3
- from fastapi import APIRouter
3
+ from typing import TYPE_CHECKING, NoReturn
4
+
5
+ from fastapi import APIRouter, Depends
4
6
  from fastapi.openapi.docs import (
5
7
  get_redoc_html,
6
8
  get_swagger_ui_html,
7
9
  )
8
- from starlette.responses import HTMLResponse
10
+ from starlette.responses import HTMLResponse # noqa: TC002
9
11
 
10
12
  from infrahub.api import (
11
13
  artifact,
@@ -21,8 +23,12 @@ from infrahub.api import (
21
23
  storage,
22
24
  transformation,
23
25
  )
26
+ from infrahub.api.dependencies import get_current_user
24
27
  from infrahub.exceptions import ResourceNotFoundError
25
28
 
29
+ if TYPE_CHECKING:
30
+ from infrahub.auth import AccountSession
31
+
26
32
  router = APIRouter(prefix="/api")
27
33
 
28
34
  router.include_router(artifact.router)
@@ -40,7 +46,9 @@ router.include_router(transformation.router)
40
46
 
41
47
 
42
48
  @router.get("/docs", include_in_schema=False)
43
- async def custom_swagger_ui_html() -> HTMLResponse:
49
+ async def custom_swagger_ui_html(
50
+ _: AccountSession = Depends(get_current_user),
51
+ ) -> HTMLResponse:
44
52
  return get_swagger_ui_html(
45
53
  openapi_url="/api/openapi.json",
46
54
  title="Infrahub - Swagger UI",
@@ -50,7 +58,7 @@ async def custom_swagger_ui_html() -> HTMLResponse:
50
58
 
51
59
 
52
60
  @router.get("/redoc", include_in_schema=False)
53
- async def redoc_html() -> HTMLResponse:
61
+ async def redoc_html(_: AccountSession = Depends(get_current_user)) -> HTMLResponse:
54
62
  return get_redoc_html(
55
63
  openapi_url="/api/openapi.json",
56
64
  title="Infrahub - ReDoc",
infrahub/api/artifact.py CHANGED
@@ -18,6 +18,7 @@ from infrahub.permissions.constants import PermissionDecisionFlag
18
18
  from infrahub.workflows.catalogue import REQUEST_ARTIFACT_DEFINITION_GENERATE
19
19
 
20
20
  if TYPE_CHECKING:
21
+ from infrahub.auth import AccountSession
21
22
  from infrahub.permissions import PermissionManager
22
23
 
23
24
  log = get_logger()
@@ -37,7 +38,7 @@ async def get_artifact(
37
38
  artifact_id: str,
38
39
  db: InfrahubDatabase = Depends(get_db),
39
40
  branch_params: BranchParams = Depends(get_branch_params),
40
- _: str = Depends(get_current_user),
41
+ _: AccountSession = Depends(get_current_user),
41
42
  ) -> Response:
42
43
  artifact = await registry.manager.get_one(db=db, id=artifact_id, branch=branch_params.branch, at=branch_params.at)
43
44
  if not artifact:
infrahub/api/auth.py CHANGED
@@ -1,3 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
1
5
  from fastapi import APIRouter, Depends, Response
2
6
 
3
7
  from infrahub import config, models
@@ -8,7 +12,9 @@ from infrahub.auth import (
8
12
  create_fresh_access_token,
9
13
  invalidate_refresh_token,
10
14
  )
11
- from infrahub.database import InfrahubDatabase
15
+
16
+ if TYPE_CHECKING:
17
+ from infrahub.database import InfrahubDatabase
12
18
 
13
19
  router = APIRouter(prefix="/auth")
14
20
 
infrahub/api/diff/diff.py CHANGED
@@ -1,13 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
- from typing import TYPE_CHECKING, Optional
4
+ from typing import TYPE_CHECKING
5
5
 
6
6
  from fastapi import APIRouter, Depends, Request
7
7
 
8
8
  from infrahub.api.dependencies import get_branch_dep, get_current_user, get_db
9
9
  from infrahub.core import registry
10
- from infrahub.core.branch import Branch # noqa: TC001
11
10
  from infrahub.core.diff.artifacts.calculator import ArtifactDiffCalculator
12
11
  from infrahub.core.diff.branch_differ import BranchDiffer
13
12
  from infrahub.core.diff.model.diff import (
@@ -15,9 +14,11 @@ from infrahub.core.diff.model.diff import (
15
14
  BranchDiffFile,
16
15
  BranchDiffRepository,
17
16
  )
18
- from infrahub.database import InfrahubDatabase # noqa: TC001
19
17
 
20
18
  if TYPE_CHECKING:
19
+ from infrahub.auth import AccountSession
20
+ from infrahub.core.branch import Branch
21
+ from infrahub.database import InfrahubDatabase
21
22
  from infrahub.services import InfrahubServices
22
23
 
23
24
 
@@ -29,17 +30,22 @@ async def get_diff_files(
29
30
  request: Request,
30
31
  db: InfrahubDatabase = Depends(get_db),
31
32
  branch: Branch = Depends(get_branch_dep),
32
- time_from: Optional[str] = None,
33
- time_to: Optional[str] = None,
33
+ time_from: str | None = None,
34
+ time_to: str | None = None,
34
35
  branch_only: bool = True,
35
- _: str = Depends(get_current_user),
36
+ _: AccountSession = Depends(get_current_user),
36
37
  ) -> dict[str, dict[str, BranchDiffRepository]]:
37
38
  response: dict[str, dict[str, BranchDiffRepository]] = defaultdict(dict)
38
39
  service: InfrahubServices = request.app.state.service
39
40
 
40
41
  # Query the Diff for all files and repository from the database
41
42
  diff = await BranchDiffer.init(
42
- db=db, branch=branch, diff_from=time_from, diff_to=time_to, branch_only=branch_only, service=service
43
+ db=db,
44
+ branch=branch,
45
+ diff_from=time_from,
46
+ diff_to=time_to,
47
+ branch_only=branch_only,
48
+ service=service,
43
49
  )
44
50
  diff_files = await diff.get_files()
45
51
 
infrahub/api/file.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Optional, Union
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  from fastapi import APIRouter, Depends, Request
6
6
  from starlette.responses import PlainTextResponse
@@ -27,13 +27,13 @@ async def get_file(
27
27
  file_path: str,
28
28
  branch_params: BranchParams = Depends(get_branch_params),
29
29
  db: InfrahubDatabase = Depends(get_db),
30
- commit: Optional[str] = None,
30
+ commit: str | None = None,
31
31
  _: str = Depends(get_current_user),
32
32
  ) -> PlainTextResponse:
33
33
  """Retrieve a file from a git repository."""
34
34
  service: InfrahubServices = request.app.state.service
35
35
 
36
- repo: Union[CoreRepository, CoreReadOnlyRepository] = await NodeManager.get_one_by_id_or_default_filter(
36
+ repo: CoreRepository | CoreReadOnlyRepository = await NodeManager.get_one_by_id_or_default_filter(
37
37
  db=db,
38
38
  id=repository_id,
39
39
  kind=InfrahubKind.GENERICREPOSITORY,
infrahub/api/internal.py CHANGED
@@ -1,16 +1,27 @@
1
+ from __future__ import annotations
2
+
1
3
  import re
2
- from typing import Optional
4
+ from typing import TYPE_CHECKING
3
5
 
4
6
  import ujson
5
- from fastapi import APIRouter, Request
7
+ from fastapi import APIRouter, Depends, Request
6
8
  from lunr.index import Index
7
9
  from pydantic import BaseModel
8
10
 
9
11
  from infrahub import config
10
- from infrahub.config import AnalyticsSettings, ExperimentalFeaturesSettings, LoggingSettings, MainSettings
12
+ from infrahub.api.dependencies import get_current_user
13
+ from infrahub.config import ( # noqa: TC001
14
+ AnalyticsSettings,
15
+ ExperimentalFeaturesSettings,
16
+ LoggingSettings,
17
+ MainSettings,
18
+ )
11
19
  from infrahub.core import registry
12
20
  from infrahub.exceptions import NodeNotFoundError
13
21
 
22
+ if TYPE_CHECKING:
23
+ from infrahub.auth import AccountSession
24
+
14
25
  router = APIRouter()
15
26
 
16
27
 
@@ -39,7 +50,7 @@ async def get_config() -> ConfigAPI:
39
50
 
40
51
 
41
52
  @router.get("/info")
42
- async def get_info(request: Request) -> InfoAPI:
53
+ async def get_info(request: Request, _: AccountSession = Depends(get_current_user)) -> InfoAPI:
43
54
  return InfoAPI(deployment_id=str(registry.id), version=request.app.version)
44
55
 
45
56
 
@@ -47,7 +58,7 @@ class SearchDocs:
47
58
  def __init__(self) -> None:
48
59
  self._title_documents: list[dict] = []
49
60
  self._heading_documents: list[dict] = []
50
- self._heading_index: Optional[Index] = None
61
+ self._heading_index: Index | None = None
51
62
 
52
63
  def _load_json(self) -> None:
53
64
  """
@@ -142,7 +153,9 @@ class SearchResultAPI(BaseModel):
142
153
 
143
154
 
144
155
  @router.get("/search/docs", include_in_schema=False)
145
- async def search_docs(query: str, limit: Optional[int] = None) -> list[SearchResultAPI]:
156
+ async def search_docs(
157
+ query: str, limit: int | None = None, _: AccountSession = Depends(get_current_user)
158
+ ) -> list[SearchResultAPI]:
146
159
  smart_query = smart_queries(query)
147
160
  search_results = search_docs_loader.heading_index.search(smart_query)
148
161
  heading_results = [
infrahub/api/oauth2.py CHANGED
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
  from urllib.parse import urljoin
5
5
 
6
+ import ujson
6
7
  from authlib.integrations.httpx_client import AsyncOAuth2Client
7
8
  from fastapi import APIRouter, Depends, Request, Response
8
9
  from fastapi.responses import JSONResponse, RedirectResponse
10
+ from opentelemetry import trace
9
11
 
10
12
  from infrahub import config, models
11
13
  from infrahub.api.dependencies import get_db
@@ -20,6 +22,8 @@ if TYPE_CHECKING:
20
22
  from infrahub.database import InfrahubDatabase
21
23
  from infrahub.services import InfrahubServices
22
24
 
25
+ # pylint: disable=R0801
26
+
23
27
  log = get_logger()
24
28
  router = APIRouter(prefix="/oauth2")
25
29
 
@@ -33,11 +37,16 @@ def _get_redirect_url(request: Request, provider_name: str) -> str:
33
37
  @router.get("/{provider_name:str}/authorize")
34
38
  async def authorize(request: Request, provider_name: str, final_url: str | None = None) -> Response:
35
39
  provider = config.SETTINGS.security.get_oauth2_provider(provider=provider_name)
36
- client = AsyncOAuth2Client(
37
- client_id=provider.client_id,
38
- client_secret=provider.client_secret,
39
- scope=provider.scopes,
40
- )
40
+
41
+ with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
42
+ span.set_attribute("provider_name", provider_name)
43
+ span.set_attribute("scopes", provider.scopes)
44
+
45
+ client = AsyncOAuth2Client(
46
+ client_id=provider.client_id,
47
+ client_secret=provider.client_secret,
48
+ scope=provider.scopes,
49
+ )
41
50
 
42
51
  redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
43
52
  final_url = final_url or config.SETTINGS.main.public_url or str(request.base_url)
@@ -88,7 +97,10 @@ async def token(
88
97
 
89
98
  token_response = await service.http.post(provider.token_url, data=token_data)
90
99
  _validate_response(response=token_response)
91
- payload = token_response.json()
100
+
101
+ with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
102
+ span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
103
+ payload = token_response.json()
92
104
 
93
105
  headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}
94
106
  if provider.userinfo_method == config.UserInfoMethod.GET:
@@ -102,7 +114,10 @@ async def token(
102
114
  if not sso_groups and config.SETTINGS.security.sso_user_default_group:
103
115
  sso_groups = [config.SETTINGS.security.sso_user_default_group]
104
116
 
105
- user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)
117
+ with trace.get_tracer(__name__).start_as_current_span("signin_sso_account") as span:
118
+ span.set_attribute("account_name", ujson.dumps(userinfo_response.json()))
119
+ span.set_attribute("sso_groups", sso_groups)
120
+ user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)
106
121
 
107
122
  response.set_cookie(
108
123
  "access_token", user_token.access_token, httponly=True, max_age=config.SETTINGS.security.access_token_lifetime
infrahub/api/oidc.py CHANGED
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
  from urllib.parse import urljoin
5
5
 
6
+ import ujson
6
7
  from authlib.integrations.httpx_client import AsyncOAuth2Client
7
8
  from fastapi import APIRouter, Depends, Request, Response
8
9
  from fastapi.responses import JSONResponse, RedirectResponse
10
+ from opentelemetry import trace
9
11
  from pydantic import BaseModel, HttpUrl
10
12
 
11
13
  from infrahub import config, models
@@ -21,6 +23,8 @@ if TYPE_CHECKING:
21
23
  from infrahub.database import InfrahubDatabase
22
24
  from infrahub.services import InfrahubServices
23
25
 
26
+ # pylint: disable=R0801
27
+
24
28
  log = get_logger()
25
29
  router = APIRouter(prefix="/oidc")
26
30
 
@@ -68,11 +72,16 @@ async def authorize(request: Request, provider_name: str, final_url: str | None
68
72
  _validate_response(response=response)
69
73
  oidc_config = OIDCDiscoveryConfig(**response.json())
70
74
 
71
- client = AsyncOAuth2Client(
72
- client_id=provider.client_id,
73
- client_secret=provider.client_secret,
74
- scope=provider.scopes,
75
- )
75
+ with trace.get_tracer(__name__).start_as_current_span("sso_oauth2_client_configuration") as span:
76
+ span.set_attribute("provider_name", provider_name)
77
+ span.set_attribute("scopes", provider.scopes)
78
+ span.set_attribute("discovery_url", provider.discovery_url)
79
+
80
+ client = AsyncOAuth2Client(
81
+ client_id=provider.client_id,
82
+ client_secret=provider.client_secret,
83
+ scope=provider.scopes,
84
+ )
76
85
 
77
86
  redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
78
87
  final_url = final_url or config.SETTINGS.main.public_url or str(request.base_url)
@@ -126,7 +135,10 @@ async def token(
126
135
 
127
136
  token_response = await service.http.post(str(oidc_config.token_endpoint), data=token_data)
128
137
  _validate_response(response=token_response)
129
- payload = token_response.json()
138
+
139
+ with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
140
+ span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
141
+ payload = token_response.json()
130
142
 
131
143
  headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}
132
144
 
@@ -138,10 +150,14 @@ async def token(
138
150
  _validate_response(response=userinfo_response)
139
151
  user_info = userinfo_response.json()
140
152
  sso_groups = user_info.get("groups", [])
153
+
141
154
  if not sso_groups and config.SETTINGS.security.sso_user_default_group:
142
155
  sso_groups = [config.SETTINGS.security.sso_user_default_group]
143
156
 
144
- user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)
157
+ with trace.get_tracer(__name__).start_as_current_span("signin_sso_account") as span:
158
+ span.set_attribute("account_name", ujson.dumps(userinfo_response.json()))
159
+ span.set_attribute("sso_groups", sso_groups)
160
+ user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)
145
161
 
146
162
  response.set_cookie(
147
163
  "access_token", user_token.access_token, httponly=True, max_age=config.SETTINGS.security.access_token_lifetime
infrahub/api/schema.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Any, Optional, Union
3
+ from typing import TYPE_CHECKING, Any, Sequence
4
4
 
5
5
  from fastapi import APIRouter, Depends, Query, Request
6
6
  from pydantic import (
@@ -72,17 +72,17 @@ class APISchemaMixin:
72
72
 
73
73
 
74
74
  class APINodeSchema(NodeSchema, APISchemaMixin):
75
- api_kind: Optional[str] = Field(default=None, alias="kind", validate_default=True)
75
+ api_kind: str | None = Field(default=None, alias="kind", validate_default=True)
76
76
  hash: str
77
77
 
78
78
 
79
79
  class APIGenericSchema(GenericSchema, APISchemaMixin):
80
- api_kind: Optional[str] = Field(default=None, alias="kind", validate_default=True)
80
+ api_kind: str | None = Field(default=None, alias="kind", validate_default=True)
81
81
  hash: str
82
82
 
83
83
 
84
84
  class APIProfileSchema(ProfileSchema, APISchemaMixin):
85
- api_kind: Optional[str] = Field(default=None, alias="kind", validate_default=True)
85
+ api_kind: str | None = Field(default=None, alias="kind", validate_default=True)
86
86
  hash: str
87
87
 
88
88
 
@@ -103,16 +103,16 @@ class SchemasLoadAPI(BaseModel):
103
103
 
104
104
 
105
105
  class JSONSchema(BaseModel):
106
- title: Optional[str] = Field(None, description="Title of the schema")
107
- description: Optional[str] = Field(None, description="Description of the schema")
106
+ title: str | None = Field(None, description="Title of the schema")
107
+ description: str | None = Field(None, description="Description of the schema")
108
108
  type: str = Field(..., description="Type of the schema element (e.g., 'object', 'array', 'string')")
109
- properties: Optional[dict[str, Any]] = Field(None, description="Properties of the object if type is 'object'")
110
- items: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = Field(
109
+ properties: dict[str, Any] | None = Field(None, description="Properties of the object if type is 'object'")
110
+ items: dict[str, Any] | list[dict[str, Any]] | None = Field(
111
111
  None, description="Items of the array if type is 'array'"
112
112
  )
113
- required: Optional[list[str]] = Field(None, description="List of required properties if type is 'object'")
114
- schema_spec: Optional[str] = Field(None, alias="$schema", description="Schema version identifier")
115
- additional_properties: Optional[Union[bool, dict[str, Any]]] = Field(
113
+ required: list[str] | None = Field(None, description="List of required properties if type is 'object'")
114
+ schema_spec: str | None = Field(None, alias="$schema", description="Schema version identifier")
115
+ additional_properties: bool | dict[str, Any] | None = Field(
116
116
  None, description="Specifies whether additional properties are allowed", alias="additionalProperties"
117
117
  )
118
118
 
@@ -128,13 +128,26 @@ class SchemaUpdate(BaseModel):
128
128
  return self.hash != self.previous_hash
129
129
 
130
130
 
131
+ def _merge_candidate_schemas(schemas: Sequence[SchemaRoot]) -> SchemaRoot:
132
+ """Merge multiple schemas into one suitable to be loaded."""
133
+ if not schemas:
134
+ raise ValueError("Cannot merge an empty list of schemas")
135
+
136
+ merged = schemas[0]
137
+ for schema in schemas[1:]:
138
+ merged = merged.merge(schema=schema)
139
+
140
+ return merged
141
+
142
+
131
143
  def evaluate_candidate_schemas(
132
144
  branch_schema: SchemaBranch, schemas_to_evaluate: SchemasLoadAPI
133
145
  ) -> tuple[SchemaBranch, SchemaUpdateValidationResult]:
134
146
  candidate_schema = branch_schema.duplicate()
147
+ schema = _merge_candidate_schemas(schemas=schemas_to_evaluate.schemas)
148
+
135
149
  try:
136
- for schema in schemas_to_evaluate.schemas:
137
- candidate_schema.load_schema(schema=schema)
150
+ candidate_schema.load_schema(schema=schema)
138
151
  candidate_schema.process()
139
152
 
140
153
  schema_diff = branch_schema.diff(other=candidate_schema)
@@ -152,7 +165,9 @@ def evaluate_candidate_schemas(
152
165
 
153
166
  @router.get("")
154
167
  async def get_schema(
155
- branch: Branch = Depends(get_branch_dep), namespaces: Union[list[str], None] = Query(default=None)
168
+ branch: Branch = Depends(get_branch_dep),
169
+ namespaces: list[str] | None = Query(default=None),
170
+ _: AccountSession = Depends(get_current_user),
156
171
  ) -> SchemaReadAPI:
157
172
  log.debug("schema_request", branch=branch.name)
158
173
  schema_branch = registry.schema.get_schema_branch(name=branch.name)
@@ -180,7 +195,9 @@ async def get_schema(
180
195
 
181
196
 
182
197
  @router.get("/summary")
183
- async def get_schema_summary(branch: Branch = Depends(get_branch_dep)) -> SchemaBranchHash:
198
+ async def get_schema_summary(
199
+ branch: Branch = Depends(get_branch_dep), _: AccountSession = Depends(get_current_user)
200
+ ) -> SchemaBranchHash:
184
201
  log.debug("schema_summary_request", branch=branch.name)
185
202
  schema_branch = registry.schema.get_schema_branch(name=branch.name)
186
203
  return schema_branch.get_hash_full()
@@ -188,13 +205,13 @@ async def get_schema_summary(branch: Branch = Depends(get_branch_dep)) -> Schema
188
205
 
189
206
  @router.get("/{schema_kind}")
190
207
  async def get_schema_by_kind(
191
- schema_kind: str, branch: Branch = Depends(get_branch_dep)
192
- ) -> Union[APIProfileSchema, APINodeSchema, APIGenericSchema]:
208
+ schema_kind: str, branch: Branch = Depends(get_branch_dep), _: AccountSession = Depends(get_current_user)
209
+ ) -> APIProfileSchema | APINodeSchema | APIGenericSchema:
193
210
  log.debug("schema_kind_request", branch=branch.name)
194
211
 
195
212
  schema = registry.schema.get(name=schema_kind, branch=branch, duplicate=False)
196
213
 
197
- api_schema: dict[str, type[Union[APIProfileSchema, APINodeSchema, APIGenericSchema]]] = {
214
+ api_schema: dict[str, type[APIProfileSchema | APINodeSchema | APIGenericSchema]] = {
198
215
  "profile": APIProfileSchema,
199
216
  "node": APINodeSchema,
200
217
  "generic": APIGenericSchema,
@@ -212,7 +229,9 @@ async def get_schema_by_kind(
212
229
 
213
230
 
214
231
  @router.get("/json_schema/{schema_kind}")
215
- async def get_json_schema_by_kind(schema_kind: str, branch: Branch = Depends(get_branch_dep)) -> JSONSchema:
232
+ async def get_json_schema_by_kind(
233
+ schema_kind: str, branch: Branch = Depends(get_branch_dep), _: AccountSession = Depends(get_current_user)
234
+ ) -> JSONSchema:
216
235
  log.debug("json_schema_kind_request", branch=branch.name)
217
236
 
218
237
  fields: dict[str, Any] = {}
@@ -368,7 +387,7 @@ async def check_schema(
368
387
  request: Request,
369
388
  schemas: SchemasLoadAPI,
370
389
  branch: Branch = Depends(get_branch_dep),
371
- _: Any = Depends(get_current_user),
390
+ _: AccountSession = Depends(get_current_user),
372
391
  ) -> JSONResponse:
373
392
  service: InfrahubServices = request.app.state.service
374
393
  log.info("schema_check_request", branch=branch.name)
infrahub/api/storage.py CHANGED
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  import hashlib
4
+ from typing import TYPE_CHECKING
2
5
 
3
6
  from fastapi import APIRouter, Depends, File, Response, UploadFile
4
7
  from infrahub_sdk.uuidt import UUIDT
@@ -8,6 +11,9 @@ from infrahub.api.dependencies import get_current_user
8
11
  from infrahub.core import registry
9
12
  from infrahub.log import get_logger
10
13
 
14
+ if TYPE_CHECKING:
15
+ from infrahub.auth import AccountSession
16
+
11
17
  log = get_logger()
12
18
  router = APIRouter(prefix="/storage")
13
19
 
@@ -22,10 +28,7 @@ class UploadContentPayload(BaseModel):
22
28
 
23
29
 
24
30
  @router.get("/object/{identifier:str}")
25
- def get_file(
26
- identifier: str,
27
- _: str = Depends(get_current_user),
28
- ) -> Response:
31
+ def get_file(identifier: str, _: AccountSession = Depends(get_current_user)) -> Response:
29
32
  content = registry.storage.retrieve(identifier=identifier)
30
33
  return Response(content=content)
31
34
 
@@ -48,10 +51,7 @@ def upload_content(
48
51
 
49
52
 
50
53
  @router.post("/upload/file")
51
- def upload_file(
52
- file: UploadFile = File(...),
53
- _: str = Depends(get_current_user),
54
- ) -> UploadResponse:
54
+ def upload_file(file: UploadFile = File(...), _: AccountSession = Depends(get_current_user)) -> UploadResponse:
55
55
  # TODO need to optimized how we read the content of the file, especially if the file is really large
56
56
  # Check this discussion for more details
57
57
  # https://stackoverflow.com/questions/63048825/how-to-upload-file-using-fastapi
infrahub/api/transformation.py CHANGED
@@ -27,6 +27,7 @@ from infrahub.transformations.models import TransformJinjaTemplateData, Transfor
27
27
  from infrahub.workflows.catalogue import TRANSFORM_JINJA2_RENDER, TRANSFORM_PYTHON_RENDER
28
28
 
29
29
  if TYPE_CHECKING:
30
+ from infrahub.auth import AccountSession
30
31
  from infrahub.services import InfrahubServices
31
32
 
32
33
  router = APIRouter()
@@ -38,7 +39,7 @@ async def transform_python(
38
39
  transform_id: str,
39
40
  db: InfrahubDatabase = Depends(get_db),
40
41
  branch_params: BranchParams = Depends(get_branch_params),
41
- _: str = Depends(get_current_user),
42
+ _: AccountSession = Depends(get_current_user),
42
43
  ) -> JSONResponse:
43
44
  params = {key: value for key, value in request.query_params.items() if key not in ["branch", "at"]}
44
45
 
@@ -97,7 +98,7 @@ async def transform_jinja2(
97
98
  transform_id: str = Path(description="ID or Name of the Jinja2 Transform to render"),
98
99
  db: InfrahubDatabase = Depends(get_db),
99
100
  branch_params: BranchParams = Depends(get_branch_params),
100
- _: str = Depends(get_current_user),
101
+ _: AccountSession = Depends(get_current_user),
101
102
  ) -> PlainTextResponse:
102
103
  params = {key: value for key, value in request.query_params.items() if key not in ["branch", "at"]}
103
104
 
infrahub/auth.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import uuid
4
4
  from datetime import datetime, timedelta, timezone
5
5
  from enum import Enum
6
- from typing import TYPE_CHECKING, Callable
6
+ from typing import TYPE_CHECKING
7
7
 
8
8
  import bcrypt
9
9
  import jwt
@@ -233,29 +233,6 @@ async def validate_api_key(db: InfrahubDatabase, token: str) -> AccountSession:
233
233
  return AccountSession(account_id=account_id, auth_type=AuthType.API)
234
234
 
235
235
 
236
- def _validate_update_account(account_session: AccountSession, node_id: str, fields: list[str]) -> None:
237
- if account_session.account_id != node_id:
238
- # A regular account is not allowed to modify another account
239
- raise PermissionError("You are not allowed to modify this account")
240
-
241
- allowed_fields = ["description", "label", "password"]
242
- for field in fields:
243
- if field not in allowed_fields:
244
- raise PermissionError(f"You are not allowed to modify '{field}'")
245
-
246
-
247
- def validate_mutation_permissions_update_node(
248
- operation: str, node_id: str, account_session: AccountSession, fields: list[str]
249
- ) -> None:
250
- validation_map: dict[str, Callable[[AccountSession, str, list[str]], None]] = {
251
- f"{InfrahubKind.ACCOUNT}Update": _validate_update_account,
252
- f"{InfrahubKind.ACCOUNT}Upsert": _validate_update_account,
253
- }
254
-
255
- if validator := validation_map.get(operation):
256
- validator(account_session, node_id, fields)
257
-
258
-
259
236
  async def invalidate_refresh_token(db: InfrahubDatabase, token_id: str) -> None:
260
237
  refresh_token = await NodeManager.get_one(id=token_id, db=db)
261
238
  if refresh_token:
infrahub/cli/__init__.py CHANGED
@@ -20,7 +20,7 @@ app = typer.Typer(name="Infrahub CLI", pretty_exceptions_enable=False)
20
20
  @app.callback()
21
21
  def common(ctx: typer.Context) -> None:
22
22
  """Infrahub CLI"""
23
- ctx.obj = CliContext(database_class=InfrahubDatabase)
23
+ ctx.obj = CliContext()
24
24
 
25
25
 
26
26
  app.add_typer(server_app, name="server")
infrahub/cli/context.py CHANGED
@@ -1,18 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING
5
4
 
6
- from infrahub.database import get_db
7
-
8
- if TYPE_CHECKING:
9
- from infrahub.database import InfrahubDatabase
5
+ from infrahub.database import InfrahubDatabase, get_db
10
6
 
11
7
 
12
8
  @dataclass
13
9
  class CliContext:
14
- database_class: type[InfrahubDatabase]
15
10
  application: str = "infrahub.server:app"
16
11
 
17
- async def get_db(self, retry: int = 0) -> InfrahubDatabase:
18
- return self.database_class(driver=await get_db(retry=retry))
12
+ # This method is inherited for Infrahub Enterprise.
13
+ @staticmethod
14
+ async def init_db(retry: int) -> InfrahubDatabase:
15
+ return InfrahubDatabase(driver=await get_db(retry=retry))