mrok 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. mrok/agent/devtools/inspector/__main__.py +3 -23
  2. mrok/agent/devtools/inspector/app.py +407 -112
  3. mrok/agent/devtools/inspector/utils.py +149 -0
  4. mrok/cli/commands/admin/bootstrap.py +2 -2
  5. mrok/cli/commands/admin/register/extensions.py +7 -9
  6. mrok/cli/commands/admin/register/instances.py +13 -16
  7. mrok/cli/commands/admin/unregister/extensions.py +7 -11
  8. mrok/cli/commands/admin/unregister/instances.py +12 -12
  9. mrok/cli/commands/agent/run/asgi.py +1 -1
  10. mrok/cli/commands/frontend/run.py +1 -1
  11. mrok/cli/main.py +17 -1
  12. mrok/cli/utils.py +26 -0
  13. mrok/conf.py +15 -7
  14. mrok/constants.py +21 -0
  15. mrok/controller/app.py +12 -10
  16. mrok/controller/auth/__init__.py +11 -0
  17. mrok/controller/auth/backends.py +60 -0
  18. mrok/controller/auth/base.py +38 -0
  19. mrok/controller/auth/manager.py +31 -0
  20. mrok/controller/auth/registry.py +17 -0
  21. mrok/frontend/app.py +94 -26
  22. mrok/frontend/main.py +8 -5
  23. mrok/frontend/middleware.py +35 -0
  24. mrok/frontend/utils.py +83 -0
  25. mrok/logging.py +24 -22
  26. mrok/proxy/app.py +13 -5
  27. mrok/proxy/middleware.py +7 -8
  28. mrok/proxy/models.py +36 -10
  29. mrok/proxy/ziticorn.py +8 -17
  30. mrok/ziti/api.py +4 -4
  31. mrok/ziti/bootstrap.py +0 -5
  32. mrok/ziti/identities.py +11 -10
  33. mrok/ziti/services.py +6 -6
  34. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/METADATA +9 -3
  35. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/RECORD +38 -35
  36. mrok/agent/devtools/__main__.py +0 -34
  37. mrok/cli/commands/agent/utils.py +0 -5
  38. mrok/controller/auth.py +0 -87
  39. mrok/proxy/constants.py +0 -22
  40. mrok/proxy/utils.py +0 -90
  41. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/WHEEL +0 -0
  42. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/entry_points.txt +0 -0
  43. {mrok-0.6.0.dist-info → mrok-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,149 @@
1
+ from collections.abc import Generator
2
+ from dataclasses import dataclass
3
+ from io import BytesIO
4
+
5
+ from multipart import MultipartParser
6
+
7
# Exact media types (lower-case, no parameters) whose bodies are shown as text.
# NOTE(review): mrok/constants.py defines identical TEXTUAL_CONTENT_TYPES and
# TEXTUAL_PREFIXES; consider importing them from there so the two stay in sync.
TEXTUAL_CONTENT_TYPES = {
    "application/json",
    "application/xml",
    "application/javascript",
    "application/x-www-form-urlencoded",
}

# Any media type starting with one of these prefixes is textual as well.
TEXTUAL_PREFIXES = ("text/",)

# Media types grouped by the syntax-highlighter language they map to. The flat
# lookup table below is derived from it; because the groups are listed in the
# same order as the original flat table, its key order is unchanged.
_LANGUAGE_CONTENT_TYPES: dict[str, tuple[str, ...]] = {
    "json": (
        "application/json",
        "application/ld+json",
        "application/problem+json",
        "application/schema+json",
    ),
    "xml": ("application/xml", "text/xml"),
    "html": ("application/xhtml+xml", "text/html"),
    "css": ("text/css",),
    "javascript": (
        "application/javascript",
        "application/x-javascript",
        "text/javascript",
        "application/ecmascript",
    ),
    "markdown": ("text/markdown", "text/x-markdown"),
    "yaml": ("application/yaml", "application/x-yaml", "text/yaml"),
    "toml": ("application/toml", "application/x-toml"),
    "sql": ("application/sql", "text/x-sql"),
    "java": ("application/java", "text/x-java-source"),
    "python": ("application/python", "text/x-python", "application/x-python-code"),
    "rust": ("application/rust", "text/x-rust"),
    "go": ("application/go", "text/x-go"),
    "bash": ("application/bash", "application/x-sh", "text/x-shellscript"),
    "regex": ("application/regex", "text/x-regex"),
}

# Media type -> highlighter language name.
CONTENT_TYPE_TO_LANGUAGE = {
    media_type: language
    for language, media_types in _LANGUAGE_CONTENT_TYPES.items()
    for media_type in media_types
}
54
+
55
+
56
+ @dataclass
57
+ class ContentTypeInfo:
58
+ content_type: str
59
+ binary: bool
60
+ charset: str | None = None
61
+ boundary: str | None = None
62
+
63
+
64
def parse_content_type(content_type_header: str) -> ContentTypeInfo:
    """Split a raw ``Content-Type`` header into a ContentTypeInfo.

    The media type is lower-cased. Of the ``key=value`` parameters, only
    ``charset`` and ``boundary`` are kept: keys are compared
    case-insensitively, surrounding double quotes are removed from values,
    and a later occurrence overrides an earlier one. Parameters without ``=``
    are ignored. A textual media type with no explicit charset gets "utf-8".
    """
    media_type, *raw_params = content_type_header.split(";")
    media_type = media_type.strip().lower()

    params: dict[str, str] = {}
    for raw in raw_params:
        name, has_value, val = raw.strip().partition("=")
        if has_value:
            params[name.strip().lower()] = val.strip().strip('"')

    is_binary = not is_textual(media_type)
    charset = params.get("charset")
    if charset is None and not is_binary:
        charset = "utf-8"

    return ContentTypeInfo(
        content_type=media_type,
        binary=is_binary,
        charset=charset,
        boundary=params.get("boundary"),
    )
90
+
91
+
92
def parse_form_data(data: bytes, boundary: str) -> Generator[tuple[str, str], None, None]:
    """Yield ``(field name, display value)`` for each part of a multipart body.

    Args:
        data: the raw ``multipart/form-data`` request body.
        boundary: the multipart boundary, without the leading dashes (as
            taken from the Content-Type ``boundary`` parameter).

    Yields:
        The decoded value for textual parts, or the placeholder "<binary>"
        for any other part.
    """
    parser = MultipartParser(BytesIO(data), boundary)
    for part in parser:
        # RFC 7578 §4.4: a part without a Content-Type header defaults to
        # text/plain. Browsers omit the header for ordinary text fields, and
        # the parser then reports an empty content type.
        content_type = part.content_type or "text/plain"
        if is_textual(content_type):
            yield part.name, part.value
        else:
            yield part.name, "<binary>"
99
+
100
+
101
def is_textual(content_type: str) -> bool:
    """Return True when *content_type* denotes human-readable text.

    Any ``;``-separated parameters (e.g. ``; charset=utf-8``) and surrounding
    whitespace are ignored, and the comparison is case-insensitive. A media
    type counts as textual when it is:

    * listed in TEXTUAL_CONTENT_TYPES, or
    * known to the syntax highlighter (CONTENT_TYPE_TO_LANGUAGE), or
    * starts with one of TEXTUAL_PREFIXES, or
    * uses the ``+json`` / ``+xml`` structured-syntax suffix (RFC 6839),
      e.g. ``application/problem+json``.
    """
    ct = content_type.split(";", 1)[0].strip().lower()
    if ct in TEXTUAL_CONTENT_TYPES or ct in CONTENT_TYPE_TO_LANGUAGE:
        return True
    if any(ct.startswith(prefix) for prefix in TEXTUAL_PREFIXES):
        return True
    return ct.endswith(("+json", "+xml"))
108
+
109
+
110
def build_tree(node, data):
    """Mirror nested dict/list *data* as children of *node*, depth first.

    *node* only needs an ``add(label)`` method that returns the new child.
    Dict entries are labelled ``str(key)`` and list items ``[index]``, and
    their values are expanded recursively. Any other value becomes a single
    leaf labelled with its ``repr``.

    NOTE(review): if *node* is a Textual TreeNode, string labels may be parsed
    as Rich markup, so keys or values containing "[...]" could render oddly or
    fail. Confirm against the Textual version in use.
    """
    if isinstance(data, dict):
        children = ((str(key), value) for key, value in data.items())
    elif isinstance(data, list):
        children = ((f"[{index}]", value) for index, value in enumerate(data))
    else:
        node.add(repr(data))
        return

    for label, value in children:
        build_tree(node.add(label), value)
121
+
122
+
123
def hexdump(data, width=16):
    """Render *data* as a classic hex dump, *width* bytes per line.

    Each line holds the bytes as space-separated two-digit lower-case hex,
    left-aligned in a ``width * 3`` column. A space follows, then the bytes
    as ASCII, with anything outside the printable range 32..126 shown as
    ".". Lines are joined with "\\n"; empty input yields "".
    """

    def render(chunk):
        hex_cols = " ".join(format(byte, "02x") for byte in chunk)
        printable = "".join(chr(byte) if 32 <= byte <= 126 else "." for byte in chunk)
        return f"{hex_cols:<{width * 3}} {printable}"

    return "\n".join(render(data[pos : pos + width]) for pos in range(0, len(data), width))
131
+
132
+
133
def humanize_bytes(num_bytes: int) -> tuple[float, str]:
    """Express *num_bytes* in the largest binary unit that keeps it below 1024.

    Returns:
        ``(value, unit)``, where value is rounded to two decimals and unit is
        one of B, KiB, MiB, GiB, TiB, PiB. PiB is the ceiling, so very large
        inputs may return a value of 1024 or more.

    Raises:
        ValueError: if *num_bytes* is negative.
    """
    if num_bytes < 0:
        raise ValueError("num_bytes must be non-negative")

    value = float(num_bytes)
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        # Compare the *rounded* value, so 1048575 B becomes (1.0, "MiB")
        # rather than the misleading (1024.0, "KiB").
        if round(value, 2) < 1024:
            return round(value, 2), unit
        value /= 1024
    return round(value, 2), "PiB"
144
+
145
+
146
def get_highlighter_language_by_content_type(content_type: str) -> str | None:
    """Return the syntax-highlighter language for *content_type*, or None.

    Any ``;``-separated parameters and surrounding whitespace are ignored, and
    matching is case-insensitive, so a raw header value such as
    ``application/json; charset=utf-8`` resolves to "json". Media types not
    listed in CONTENT_TYPE_TO_LANGUAGE fall back on their RFC 6839
    structured-syntax suffix: ``+json`` -> "json", ``+xml`` -> "xml".
    """
    media_type = content_type.split(";", 1)[0].strip().lower()
    language = CONTENT_TYPE_TO_LANGUAGE.get(media_type)
    if language is not None:
        return language
    if media_type.endswith("+json"):
        return "json"
    if media_type.endswith("+xml"):
        return "xml"
    return None
@@ -22,8 +22,8 @@ async def bootstrap(
22
22
  return await bootstrap_identity(
23
23
  mgmt_api,
24
24
  client_api,
25
- settings.proxy.identity,
26
- settings.proxy.mode,
25
+ settings.frontend.identity,
26
+ settings.frontend.mode,
27
27
  forced,
28
28
  tags,
29
29
  )
@@ -5,8 +5,8 @@ import typer
5
5
  from rich import print
6
6
 
7
7
  from mrok.cli.commands.admin.utils import parse_tags
8
- from mrok.conf import Settings
9
- from mrok.constants import RE_EXTENSION_ID
8
+ from mrok.cli.utils import validate_extension_id
9
+ from mrok.conf import Settings, get_settings
10
10
  from mrok.ziti.api import ZitiManagementAPI
11
11
  from mrok.ziti.services import register_service
12
12
 
@@ -16,18 +16,16 @@ async def do_register(settings: Settings, extension_id: str, tags: list[str] | N
16
16
  await register_service(settings, api, extension_id, tags=parse_tags(tags))
17
17
 
18
18
 
19
- def validate_extension_id(extension_id: str) -> str:
20
- if not RE_EXTENSION_ID.fullmatch(extension_id):
21
- raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
22
- return extension_id
23
-
24
-
25
19
  def register(app: typer.Typer) -> None:
20
+ settings = get_settings()
21
+
26
22
  @app.command("extension")
27
23
  def register_extension(
28
24
  ctx: typer.Context,
29
25
  extension_id: str = typer.Argument(
30
- ..., callback=validate_extension_id, help="Extension ID in format EXT-xxxx-yyyy"
26
+ ...,
27
+ callback=validate_extension_id,
28
+ help=f"Extension ID in the format {settings.identifiers.extension.format}",
31
29
  ),
32
30
  tags: Annotated[
33
31
  list[str] | None,
@@ -6,8 +6,11 @@ from typing import Annotated
6
6
  import typer
7
7
 
8
8
  from mrok.cli.commands.admin.utils import parse_tags
9
- from mrok.conf import Settings
10
- from mrok.constants import RE_EXTENSION_ID, RE_INSTANCE_ID
9
+ from mrok.cli.utils import (
10
+ validate_extension_id,
11
+ validate_instance_id,
12
+ )
13
+ from mrok.conf import Settings, get_settings
11
14
  from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI
12
15
  from mrok.ziti.identities import register_identity
13
16
 
@@ -21,27 +24,21 @@ async def do_register(
21
24
  )
22
25
 
23
26
 
24
- def validate_extension_id(extension_id: str):
25
- if not RE_EXTENSION_ID.fullmatch(extension_id):
26
- raise typer.BadParameter("it must match EXT-xxxx-yyyy (case-insensitive)")
27
- return extension_id
28
-
29
-
30
- def validate_instance_id(instance_id: str):
31
- if not RE_INSTANCE_ID.fullmatch(instance_id):
32
- raise typer.BadParameter("it must match INS-xxxx-yyyy-zzzz (case-insensitive)")
33
- return instance_id
34
-
35
-
36
27
  def register(app: typer.Typer) -> None:
28
+ settings = get_settings()
29
+
37
30
  @app.command("instance")
38
31
  def register_instance(
39
32
  ctx: typer.Context,
40
33
  extension_id: str = typer.Argument(
41
- ..., callback=validate_extension_id, help="Extension ID in format EXT-xxxx-yyyy"
34
+ ...,
35
+ callback=validate_extension_id,
36
+ help=f"Extension ID in the format {settings.identifiers.extension.format}",
42
37
  ),
43
38
  instance_id: str = typer.Argument(
44
- ..., callback=validate_instance_id, help="Instance ID in format INS-xxxx-yyyy-zzzz"
39
+ ...,
40
+ callback=validate_instance_id,
41
+ help=f"Instance ID in the format {settings.identifiers.instance.format}",
45
42
  ),
46
43
  output: Path = typer.Argument(
47
44
  ...,
@@ -1,32 +1,28 @@
1
1
  import asyncio
2
- import re
3
2
 
4
3
  import typer
5
4
 
6
- from mrok.conf import Settings
5
+ from mrok.cli.utils import validate_extension_id
6
+ from mrok.conf import Settings, get_settings
7
7
  from mrok.ziti.api import ZitiManagementAPI
8
8
  from mrok.ziti.services import unregister_service
9
9
 
10
- RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
11
-
12
10
 
13
async def do_unregister(settings: Settings, extension_id: str):
    """Unregister the service of *extension_id* through the Ziti management API."""
    async with ZitiManagementAPI(settings) as mgmt_api:
        await unregister_service(settings, mgmt_api, extension_id)
16
14
 
17
15
 
18
- def validate_extension_id(extension_id: str):
19
- if not RE_EXTENSION_ID.fullmatch(extension_id):
20
- raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
21
- return extension_id
22
-
23
-
24
16
  def register(app: typer.Typer) -> None:
17
+ settings = get_settings()
18
+
25
19
  @app.command("extension")
26
20
  def unregister_extension(
27
21
  ctx: typer.Context,
28
22
  extension_id: str = typer.Argument(
29
- ..., callback=validate_extension_id, help="Extension ID in format EXT-xxxx-yyyy"
23
+ ...,
24
+ callback=validate_extension_id,
25
+ help=f"Extension ID in the format {settings.identifiers.extension.format}",
30
26
  ),
31
27
  ):
32
28
  """Unregister a new Extension in OpenZiti (service)."""
@@ -1,34 +1,34 @@
1
1
  import asyncio
2
- import re
3
2
 
4
3
  import typer
5
4
 
6
- from mrok.conf import Settings
5
+ from mrok.cli.utils import validate_extension_id, validate_instance_id
6
+ from mrok.conf import Settings, get_settings
7
7
  from mrok.ziti.api import ZitiManagementAPI
8
8
  from mrok.ziti.identities import unregister_identity
9
9
 
10
- RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
11
-
12
10
 
13
async def do_unregister(settings: Settings, extension_id: str, instance_id: str):
    """Unregister the identity of *instance_id* (under *extension_id*) via the Ziti management API."""
    async with ZitiManagementAPI(settings) as mgmt_api:
        await unregister_identity(settings, mgmt_api, extension_id, instance_id)
16
14
 
17
15
 
18
- def validate_extension_id(extension_id: str):
19
- if not RE_EXTENSION_ID.fullmatch(extension_id):
20
- raise typer.BadParameter("ext_id must match EXT-xxxx-yyyy (case-insensitive)")
21
- return extension_id
22
-
23
-
24
def register(app: typer.Typer) -> None:
    """Attach the ``instance`` unregister command to *app*."""
    settings = get_settings()

    @app.command("instance")
    def unregister_instance(
        ctx: typer.Context,
        extension_id: str = typer.Argument(
            ...,
            callback=validate_extension_id,
            help=f"Extension ID in the format {settings.identifiers.extension.format}",
        ),
        instance_id: str = typer.Argument(
            ...,
            callback=validate_instance_id,
            help=f"Instance ID in the format {settings.identifiers.instance.format}",
        ),
    ):
        """Unregister an Extension Instance from OpenZiti (identity)."""
        # ctx.obj is the Settings object stored by the CLI root callback.
        asyncio.run(do_unregister(ctx.obj, extension_id, instance_id))
@@ -12,8 +12,8 @@ default_workers = number_of_workers()
12
12
  def register(app: typer.Typer) -> None:
13
13
  @app.command("asgi")
14
14
  def run_asgi(
15
- app: Annotated[str, typer.Argument(..., help="ASGI application")],
16
15
  identity_file: Annotated[Path, typer.Argument(..., help="Identity json file")],
16
+ app: Annotated[str, typer.Argument(..., help="ASGI application")],
17
17
  workers: Annotated[
18
18
  int,
19
19
  typer.Option(
@@ -30,7 +30,7 @@ def register(app: typer.Typer) -> None:
30
30
  int,
31
31
  typer.Option(
32
32
  "--port",
33
- "-P",
33
+ "-p",
34
34
  help="Port to bind to. Default: 8000",
35
35
  show_default=True,
36
36
  ),
mrok/cli/main.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import inspect
2
2
  import sys
3
+ from typing import Annotated
3
4
 
4
5
  import typer
5
6
  from pyfiglet import Figlet
@@ -79,11 +80,23 @@ for name, module in inspect.getmembers(commands):
79
80
  elif hasattr(module, "app"): # pragma: no branch
80
81
  app.add_typer(module.app, name=name.replace("_", "-"))
81
82
 
83
+ _debug_mode = False
84
+
82
85
 
83
86
  @app.callback()
84
87
  def main(
85
88
  ctx: typer.Context,
89
+ debug: Annotated[
90
+ bool,
91
+ typer.Option(
92
+ "--debug",
93
+ help="Run the CLI in debug mode",
94
+ show_default=True,
95
+ ),
96
+ ] = False,
86
97
  ):
98
+ global _debug_mode
99
+ _debug_mode = debug
87
100
  settings = get_settings()
88
101
  setup_logging(settings, cli_mode=True)
89
102
  ctx.obj = settings
@@ -93,5 +106,8 @@ def run():
93
106
  try:
94
107
  app()
95
108
  except Exception as e:
96
- err_console.print(f"[bold red]Error:[/bold red] {e}")
109
+ if _debug_mode:
110
+ raise
111
+ message = str(e) or "Unexpected error. Debug it with --debug"
112
+ err_console.print(f"[bold red]Error:[/bold red] {message}")
97
113
  sys.exit(-1)
mrok/cli/utils.py CHANGED
@@ -1,5 +1,31 @@
1
1
  import multiprocessing
2
+ import re
3
+
4
+ import typer
5
+
6
+ from mrok.conf import get_settings
2
7
 
3
8
 
4
def number_of_workers() -> int:
    """Return the default worker count: two per CPU core, plus one."""
    cores = multiprocessing.cpu_count()
    return 2 * cores + 1
11
+
12
+
13
def validate_identifier(regex_exp: str, format: str, identifier: str) -> str:
    """Return *identifier* unchanged if it matches *regex_exp* in full.

    Raises:
        typer.BadParameter: if it does not match. The message shows the
            human-readable *format* rather than the regex.
    """
    if re.fullmatch(regex_exp, identifier) is None:
        raise typer.BadParameter(f"it must match {format}")
    return identifier
18
+
19
+
20
def validate_extension_id(extension_id: str) -> str:
    """Typer callback: check *extension_id* against the configured extension pattern."""
    spec = get_settings().identifiers.extension
    return validate_identifier(spec.regex, spec.format, extension_id)
25
+
26
+
27
def validate_instance_id(instance_id: str) -> str:
    """Typer callback: check *instance_id* against the configured instance pattern."""
    spec = get_settings().identifiers.instance
    return validate_identifier(spec.regex, spec.format, instance_id)
mrok/conf.py CHANGED
@@ -7,19 +7,27 @@ DEFAULT_SETTINGS = {
7
7
  "debug": False,
8
8
  "rich": False,
9
9
  },
10
- "PROXY": {
10
+ "FRONTEND": {
11
11
  "identity": "public",
12
12
  "mode": "zrok",
13
13
  },
14
14
  "ZITI": {
15
15
  "ssl_verify": False,
16
16
  },
17
- "PAGINATION": {"limit": 50},
18
- "SIDECAR": {
19
- "textual_port": 4040,
20
- "store_port": 5051,
21
- "store_size": 1000,
22
- "textual_command": "python mrok/agent/sidecar/inspector.py",
17
+ "CONTROLLER": {
18
+ "pagination": {"limit": 50},
19
+ },
20
+ "IDENTIFIERS": {
21
+ "extension": {
22
+ "regex": "(?i)EXT-\\d{4}-\\d{4}",
23
+ "format": "EXT-xxxx-yyyy",
24
+ "example": "EXT-2000-1000",
25
+ },
26
+ "instance": {
27
+ "regex": "(?i)INS-\\d{4}-\\d{4}-\\d{4}",
28
+ "format": "INS-xxxx-yyyy-zzzz",
29
+ "example": "INS-2004-2000-3000",
30
+ },
23
31
  },
24
32
  }
25
33
 
mrok/constants.py CHANGED
@@ -2,3 +2,24 @@ import re
2
2
 
3
RE_EXTENSION_ID = re.compile(r"(?i)EXT-\d{4}-\d{4}")
RE_INSTANCE_ID = re.compile(r"(?i)INS-\d{4}-\d{4}-\d{4}")


# Content-type classification tables (their consumers are not in this
# module). Exact media types are sets, for membership tests. Prefixes are
# tuples, so they can be passed straight to str.startswith().
BINARY_CONTENT_TYPES = {"application/octet-stream", "application/pdf"}
BINARY_PREFIXES = ("image/", "video/", "audio/")

TEXTUAL_CONTENT_TYPES = {
    "application/json",
    "application/xml",
    "application/javascript",
    "application/x-www-form-urlencoded",
}
TEXTUAL_PREFIXES = ("text/",)
mrok/controller/app.py CHANGED
@@ -5,8 +5,8 @@ import fastapi_pagination
5
5
  from fastapi import Depends, FastAPI
6
6
  from fastapi.routing import APIRoute, APIRouter
7
7
 
8
- from mrok.conf import get_settings
9
- from mrok.controller.auth import authenticate
8
+ from mrok.conf import Settings, get_settings
9
+ from mrok.controller.auth import HTTPAuthManager
10
10
  from mrok.controller.openapi import generate_openapi_spec
11
11
  from mrok.controller.routes.extensions import router as extensions_router
12
12
  from mrok.controller.routes.instances import router as instances_router
@@ -36,7 +36,8 @@ def setup_custom_serialization(router: APIRouter):
36
36
  api_route.response_model_exclude_none = True
37
37
 
38
38
 
39
- def setup_app():
39
+ def setup_app(settings: Settings):
40
+ auth_manager = HTTPAuthManager(settings.controller.auth)
40
41
  app = FastAPI(
41
42
  title="mrok Controller API",
42
43
  description="API to orchestrate OpenZiti for Extensions.",
@@ -49,22 +50,23 @@ def setup_app():
49
50
 
50
51
  setup_custom_serialization(extensions_router)
51
52
 
52
- # TODO: Add healthcheck
53
+ @app.get("/healthcheck")
54
+ async def healthcheck():
55
+ return {"status": "healthy"}
56
+
53
57
  app.include_router(
54
58
  extensions_router,
55
59
  prefix="/extensions",
56
- dependencies=[Depends(authenticate)],
60
+ dependencies=[Depends(auth_manager)],
57
61
  )
58
62
  app.include_router(
59
63
  instances_router,
60
64
  prefix="/instances",
61
- dependencies=[Depends(authenticate)],
65
+ dependencies=[Depends(auth_manager)],
62
66
  )
63
67
 
64
- settings = get_settings()
65
-
66
- app.openapi = partial(generate_openapi_spec, app, settings)
68
+ app.openapi = partial(generate_openapi_spec, app, settings) # type: ignore[method-assign]
67
69
  return app
68
70
 
69
71
 
70
- app = setup_app()
72
+ app = setup_app(get_settings())
@@ -0,0 +1,11 @@
1
+ from mrok.controller.auth.backends import OIDCJWTAuthenticationBackend # noqa: F401
2
+ from mrok.controller.auth.base import AuthIdentity, BaseHTTPAuthBackend
3
+ from mrok.controller.auth.manager import HTTPAuthManager
4
+ from mrok.controller.auth.registry import register_authentication_backend
5
+
6
+ __all__ = [
7
+ "AuthIdentity",
8
+ "BaseHTTPAuthBackend",
9
+ "HTTPAuthManager",
10
+ "register_authentication_backend",
11
+ ]
@@ -0,0 +1,60 @@
1
+ import logging
2
+
3
+ import httpx
4
+ import jwt
5
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
6
+ from fastapi.security.http import HTTPBase
7
+
8
+ from mrok.controller.auth.base import UNAUTHORIZED_EXCEPTION, AuthIdentity, BaseHTTPAuthBackend
9
+ from mrok.controller.auth.registry import register_authentication_backend
10
+
11
logger = logging.getLogger("mrok.controller")


@register_authentication_backend("oidc")
class OIDCJWTAuthenticationBackend(BaseHTTPAuthBackend):
    """Validate bearer tokens as JWTs signed by an OpenID Connect provider.

    Reads two keys from ``self.config``:

    * ``config_url``: URL of the provider's OpenID configuration document,
      which must expose ``issuer`` and ``jwks_uri``.
    * ``audience``: expected ``aud`` claim.

    NOTE(review): the configuration document and the JWKS are downloaded on
    every request. Consider caching them (refreshing on an unknown ``kid``)
    if provider latency or rate limits become a problem.
    """

    def init_scheme(self) -> HTTPBase:
        # auto_error=False: a request without credentials yields None, and
        # BaseHTTPAuthBackend.__call__ then returns None without authenticating.
        return HTTPBearer(auto_error=False)

    @staticmethod
    def _unauthorized():
        """Return a fresh copy of the shared 401 exception.

        Re-raising the one module-level UNAUTHORIZED_EXCEPTION instance keeps
        extending its ``__traceback__``, and with it the frames it holds. So
        each failure raises a new instance with the same status, detail and
        headers.
        """
        template = UNAUTHORIZED_EXCEPTION
        return type(template)(
            status_code=template.status_code,
            detail=template.detail,
            headers=template.headers,
        )

    async def authenticate(self, credentials: HTTPAuthorizationCredentials) -> AuthIdentity | None:
        """Verify the bearer token and return the identity it asserts.

        Returns:
            AuthIdentity with ``subject`` set from the ``sub`` claim and
            ``metadata`` set to the full decoded payload.

        Raises:
            HTTPException: 401 in any of these cases:
                * the token is malformed or has no ``kid``;
                * the provider metadata or JWKS cannot be fetched;
                * no matching key exists;
                * signature, issuer or audience verification fails;
                * the token carries no usable ``sub`` claim.
        """
        token = credentials.credentials

        # Inspect the (unverified) header first, so malformed tokens are
        # rejected without outbound requests. The token itself is never logged.
        try:
            header = jwt.get_unverified_header(token)
        except jwt.PyJWTError as e:
            logger.error(f"Invalid jwt token: {e}")
            raise self._unauthorized() from None
        kid = header.get("kid")
        if kid is None:
            logger.error("Invalid jwt token: missing key ID (kid) header")
            raise self._unauthorized()

        async with httpx.AsyncClient() as client:
            try:
                config_resp = await client.get(self.config.config_url)
                config_resp.raise_for_status()
                config = config_resp.json()
                issuer = config["issuer"]
                jwks_uri = config["jwks_uri"]

                jwks_resp = await client.get(jwks_uri)
                jwks_resp.raise_for_status()
                jwks = jwks_resp.json()

                key_data = next((k for k in jwks["keys"] if k.get("kid") == kid), None)
            except Exception:
                logger.exception("Error fetching openid-config/jwks")
                raise self._unauthorized() from None
        if key_data is None:
            logger.error("Key ID not found in JWKS")
            raise self._unauthorized()

        try:
            signing_key = jwt.PyJWK(key_data)
            # Accept only the algorithm bound to the provider's key, never the
            # one requested by the (untrusted) token header.
            payload = jwt.decode(
                token,
                signing_key,
                algorithms=[signing_key.algorithm_name],
                issuer=issuer,
                audience=self.config.audience,
            )
        except jwt.PyJWTError as e:
            # Covers invalid tokens as well as unusable JWKs (PyJWKError,
            # InvalidKeyError).
            logger.error(f"Invalid jwt token: {e}")
            raise self._unauthorized() from None

        subject = payload.get("sub")
        if not isinstance(subject, str) or not subject:
            logger.error("Invalid jwt token: missing or invalid 'sub' claim")
            raise self._unauthorized()
        return AuthIdentity(subject=subject, metadata=payload)
@@ -0,0 +1,38 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from dynaconf.utils.boxing import DynaBox
5
+ from fastapi import HTTPException, Request, status
6
+ from fastapi.security import HTTPAuthorizationCredentials
7
+ from fastapi.security.http import HTTPBase
8
+ from pydantic import BaseModel
9
+
10
# Canonical 401 error for requests that fail authentication.
# NOTE(review): RFC 9110 §15.5.2 requires a 401 response to carry a
# WWW-Authenticate header. Consider adding headers matching the configured
# scheme; confirm what clients expect.
UNAUTHORIZED_EXCEPTION = HTTPException(
    detail="Unauthorized.",
    status_code=status.HTTP_401_UNAUTHORIZED,
)
13
+
14
+
15
class AuthIdentity(BaseModel):
    """Identity established by an authentication backend for one request.

    ``scopes`` and ``metadata`` default to empty containers. Pydantic
    documents that it copies such mutable defaults per instance, so they are
    not shared between models.
    """

    # Principal identifier (the OIDC backend uses the JWT "sub" claim).
    subject: str
    # Granted scopes; empty unless a backend fills them in.
    scopes: list[str] = []
    # Free-form backend data (the OIDC backend stores the decoded JWT payload).
    metadata: dict[str, Any] = {}
19
+
20
+
21
class BaseHTTPAuthBackend(ABC):
    """Base class for pluggable HTTP authentication backends.

    Subclasses supply the FastAPI security scheme that extracts credentials
    from a request (``init_scheme``) and the logic that validates them
    (``authenticate``). Instances are async callables that take the request.
    """

    def __init__(self, config: DynaBox):
        # Backend-specific settings; the security scheme is built once here.
        self.config = config
        self.scheme = self.init_scheme()

    @abstractmethod
    def init_scheme(self) -> HTTPBase:
        """Return the security scheme used to extract request credentials."""
        raise NotImplementedError()

    @abstractmethod
    async def authenticate(self, credentials: HTTPAuthorizationCredentials) -> AuthIdentity | None:
        """Validate *credentials* and return the resulting identity."""
        raise NotImplementedError()

    async def __call__(self, request: Request) -> AuthIdentity | None:
        """Extract credentials and authenticate them; return None when none are present."""
        credentials = await self.scheme(request)
        return None if not credentials else await self.authenticate(credentials)