agenta-0.30.0a1-py3-none-any.whl → agenta-0.30.0a3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic; see the release's registry page for more details.

Files changed (71)
  1. agenta/__init__.py +1 -0
  2. agenta/client/backend/__init__.py +32 -3
  3. agenta/client/backend/access_control/__init__.py +1 -0
  4. agenta/client/backend/access_control/client.py +167 -0
  5. agenta/client/backend/apps/client.py +70 -10
  6. agenta/client/backend/client.py +61 -45
  7. agenta/client/backend/configs/client.py +6 -0
  8. agenta/client/backend/containers/client.py +6 -0
  9. agenta/client/backend/core/file.py +13 -8
  10. agenta/client/backend/environments/client.py +6 -0
  11. agenta/client/backend/evaluations/client.py +14 -1
  12. agenta/client/backend/evaluators/client.py +24 -0
  13. agenta/client/backend/observability/client.py +22 -16
  14. agenta/client/backend/observability_v_1/__init__.py +2 -2
  15. agenta/client/backend/observability_v_1/client.py +203 -0
  16. agenta/client/backend/observability_v_1/types/__init__.py +2 -1
  17. agenta/client/backend/observability_v_1/types/format.py +1 -1
  18. agenta/client/backend/observability_v_1/types/query_analytics_response.py +7 -0
  19. agenta/client/backend/scopes/__init__.py +1 -0
  20. agenta/client/backend/scopes/client.py +114 -0
  21. agenta/client/backend/testsets/client.py +305 -121
  22. agenta/client/backend/types/__init__.py +24 -2
  23. agenta/client/backend/types/analytics_response.py +24 -0
  24. agenta/client/backend/types/app.py +2 -1
  25. agenta/client/backend/types/body_import_testset.py +0 -1
  26. agenta/client/backend/types/bucket_dto.py +26 -0
  27. agenta/client/backend/types/header_dto.py +22 -0
  28. agenta/client/backend/types/legacy_analytics_response.py +29 -0
  29. agenta/client/backend/types/legacy_data_point.py +27 -0
  30. agenta/client/backend/types/metrics_dto.py +24 -0
  31. agenta/client/backend/types/permission.py +1 -0
  32. agenta/client/backend/types/projects_response.py +28 -0
  33. agenta/client/backend/types/provider_key_dto.py +23 -0
  34. agenta/client/backend/types/provider_kind.py +21 -0
  35. agenta/client/backend/types/secret_dto.py +24 -0
  36. agenta/client/backend/types/secret_kind.py +5 -0
  37. agenta/client/backend/types/secret_response_dto.py +27 -0
  38. agenta/client/backend/variants/client.py +66 -0
  39. agenta/client/backend/vault/__init__.py +1 -0
  40. agenta/client/backend/vault/client.py +685 -0
  41. agenta/client/client.py +1 -1
  42. agenta/sdk/__init__.py +1 -0
  43. agenta/sdk/agenta_init.py +47 -118
  44. agenta/sdk/assets.py +57 -46
  45. agenta/sdk/context/exporting.py +25 -0
  46. agenta/sdk/context/routing.py +12 -12
  47. agenta/sdk/context/tracing.py +26 -1
  48. agenta/sdk/decorators/routing.py +272 -267
  49. agenta/sdk/decorators/tracing.py +53 -31
  50. agenta/sdk/managers/config.py +8 -118
  51. agenta/sdk/managers/secrets.py +38 -0
  52. agenta/sdk/middleware/auth.py +128 -93
  53. agenta/sdk/middleware/cache.py +4 -0
  54. agenta/sdk/middleware/config.py +254 -0
  55. agenta/sdk/middleware/cors.py +27 -0
  56. agenta/sdk/middleware/otel.py +40 -0
  57. agenta/sdk/middleware/vault.py +158 -0
  58. agenta/sdk/tracing/exporters.py +40 -2
  59. agenta/sdk/tracing/inline.py +2 -2
  60. agenta/sdk/tracing/processors.py +11 -3
  61. agenta/sdk/tracing/tracing.py +14 -12
  62. agenta/sdk/utils/constants.py +1 -0
  63. agenta/sdk/utils/exceptions.py +20 -19
  64. agenta/sdk/utils/globals.py +4 -8
  65. agenta/sdk/utils/timing.py +58 -0
  66. {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/METADATA +3 -2
  67. {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/RECORD +69 -44
  68. {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/WHEEL +1 -1
  69. agenta/client/backend/types/lm_providers_enum.py +0 -21
  70. agenta/sdk/tracing/context.py +0 -24
  71. {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/entry_points.txt +0 -0
@@ -1,8 +1,12 @@
1
1
  from typing import Callable, Optional, Any, Dict, List, Union
2
+
2
3
  from functools import wraps
3
4
  from itertools import chain
4
5
  from inspect import iscoroutinefunction, getfullargspec
5
6
 
7
+ from opentelemetry import baggage as baggage
8
+ from opentelemetry.context import attach, detach
9
+
6
10
  from agenta.sdk.utils.exceptions import suppress
7
11
  from agenta.sdk.context.tracing import tracing_context
8
12
  from agenta.sdk.tracing.conventions import parse_span_kind
@@ -39,10 +43,12 @@ class instrument: # pylint: disable=invalid-name
39
43
  is_coroutine_function = iscoroutinefunction(func)
40
44
 
41
45
  @wraps(func)
42
- async def async_wrapper(*args, **kwargs):
43
- async def _async_auto_instrumented(*args, **kwargs):
46
+ async def awrapper(*args, **kwargs):
47
+ async def aauto_instrumented(*args, **kwargs):
44
48
  self._parse_type_and_kind()
45
49
 
50
+ token = self._attach_baggage()
51
+
46
52
  with ag.tracer.start_as_current_span(func.__name__, kind=self.kind):
47
53
  self._pre_instrument(func, *args, **kwargs)
48
54
 
@@ -52,13 +58,17 @@ class instrument: # pylint: disable=invalid-name
52
58
 
53
59
  return result
54
60
 
55
- return await _async_auto_instrumented(*args, **kwargs)
61
+ self._detach_baggage(token)
62
+
63
+ return await aauto_instrumented(*args, **kwargs)
56
64
 
57
65
  @wraps(func)
58
- def sync_wrapper(*args, **kwargs):
59
- def _sync_auto_instrumented(*args, **kwargs):
66
+ def wrapper(*args, **kwargs):
67
+ def auto_instrumented(*args, **kwargs):
60
68
  self._parse_type_and_kind()
61
69
 
70
+ token = self._attach_baggage()
71
+
62
72
  with ag.tracer.start_as_current_span(func.__name__, kind=self.kind):
63
73
  self._pre_instrument(func, *args, **kwargs)
64
74
 
@@ -68,9 +78,11 @@ class instrument: # pylint: disable=invalid-name
68
78
 
69
79
  return result
70
80
 
71
- return _sync_auto_instrumented(*args, **kwargs)
81
+ self._detach_baggage(token)
82
+
83
+ return auto_instrumented(*args, **kwargs)
72
84
 
73
- return async_wrapper if is_coroutine_function else sync_wrapper
85
+ return awrapper if is_coroutine_function else wrapper
74
86
 
75
87
  def _parse_type_and_kind(self):
76
88
  if not ag.tracing.get_current_span().is_recording():
@@ -78,6 +90,25 @@ class instrument: # pylint: disable=invalid-name
78
90
 
79
91
  self.kind = parse_span_kind(self.type)
80
92
 
93
+ def _attach_baggage(self):
94
+ context = tracing_context.get()
95
+
96
+ references = context.references
97
+
98
+ token = None
99
+ if references:
100
+ for k, v in references.items():
101
+ token = attach(baggage.set_baggage(f"ag.refs.{k}", v))
102
+
103
+ return token
104
+
105
+ def _detach_baggage(
106
+ self,
107
+ token,
108
+ ):
109
+ if token:
110
+ detach(token)
111
+
81
112
  def _pre_instrument(
82
113
  self,
83
114
  func,
@@ -86,29 +117,21 @@ class instrument: # pylint: disable=invalid-name
86
117
  ):
87
118
  span = ag.tracing.get_current_span()
88
119
 
120
+ context = tracing_context.get()
121
+
89
122
  with suppress():
123
+ trace_id = span.context.trace_id
124
+
125
+ ag.tracing.credentials[trace_id] = context.credentials
126
+
90
127
  span.set_attributes(
91
128
  attributes={"node": self.type},
92
129
  namespace="type",
93
130
  )
94
131
 
95
132
  if span.parent is None:
96
- rctx = tracing_context.get()
97
-
98
- span.set_attributes(
99
- attributes={"configuration": rctx.get("config", {})},
100
- namespace="meta",
101
- )
102
- span.set_attributes(
103
- attributes={"environment": rctx.get("environment", {})},
104
- namespace="meta",
105
- )
106
133
  span.set_attributes(
107
- attributes={"version": rctx.get("version", {})},
108
- namespace="meta",
109
- )
110
- span.set_attributes(
111
- attributes={"variant": rctx.get("variant", {})},
134
+ attributes={"configuration": context.parameters or {}},
112
135
  namespace="meta",
113
136
  )
114
137
 
@@ -118,6 +141,7 @@ class instrument: # pylint: disable=invalid-name
118
141
  io=self._parse(func, *args, **kwargs),
119
142
  ignore=self.ignore_inputs,
120
143
  )
144
+
121
145
  span.set_attributes(
122
146
  attributes={"inputs": _inputs},
123
147
  namespace="data",
@@ -161,6 +185,7 @@ class instrument: # pylint: disable=invalid-name
161
185
  io=self._patch(result),
162
186
  ignore=self.ignore_outputs,
163
187
  )
188
+
164
189
  span.set_attributes(
165
190
  attributes={"outputs": _outputs},
166
191
  namespace="data",
@@ -171,15 +196,12 @@ class instrument: # pylint: disable=invalid-name
171
196
 
172
197
  with suppress():
173
198
  if hasattr(span, "parent") and span.parent is None:
174
- tracing_context.set(
175
- tracing_context.get()
176
- | {
177
- "root": {
178
- "trace_id": span.get_span_context().trace_id,
179
- "span_id": span.get_span_context().span_id,
180
- }
181
- }
182
- )
199
+ context = tracing_context.get()
200
+ context.link = {
201
+ "tree_id": span.get_span_context().trace_id,
202
+ "node_id": span.get_span_context().span_id,
203
+ }
204
+ tracing_context.set(context)
183
205
 
184
206
  def _parse(
185
207
  self,
@@ -7,7 +7,7 @@ import yaml
7
7
  from pydantic import BaseModel
8
8
 
9
9
  from agenta.sdk.managers.shared import SharedManager
10
- from agenta.sdk.decorators.routing import routing_context
10
+ from agenta.sdk.context.routing import routing_context
11
11
 
12
12
  T = TypeVar("T", bound=BaseModel)
13
13
 
@@ -20,7 +20,7 @@ class ConfigManager:
20
20
  @staticmethod
21
21
  def get_from_route(
22
22
  schema: Optional[Type[T]] = None,
23
- ) -> Union[Dict[str, Any], T]:
23
+ ) -> Optional[Union[Dict[str, Any], T]]:
24
24
  """
25
25
  Retrieves the configuration from the route context and returns a config object.
26
26
 
@@ -47,125 +47,15 @@ class ConfigManager:
47
47
 
48
48
  context = routing_context.get()
49
49
 
50
- parameters = None
51
-
52
- if "config" in context and context["config"]:
53
- parameters = context["config"]
54
-
55
- else:
56
- app_id: Optional[str] = None
57
- app_slug: Optional[str] = None
58
- variant_id: Optional[str] = None
59
- variant_slug: Optional[str] = None
60
- variant_version: Optional[int] = None
61
- environment_id: Optional[str] = None
62
- environment_slug: Optional[str] = None
63
- environment_version: Optional[int] = None
64
-
65
- if "application" in context:
66
- app_id = context["application"].get("id")
67
- app_slug = context["application"].get("slug")
68
-
69
- if "variant" in context:
70
- variant_id = context["variant"].get("id")
71
- variant_slug = context["variant"].get("slug")
72
- variant_version = context["variant"].get("version")
73
-
74
- if "environment" in context:
75
- environment_id = context["environment"].get("id")
76
- environment_slug = context["environment"].get("slug")
77
- environment_version = context["environment"].get("version")
78
-
79
- parameters = ConfigManager.get_from_registry(
80
- app_id=app_id,
81
- app_slug=app_slug,
82
- variant_id=variant_id,
83
- variant_slug=variant_slug,
84
- variant_version=variant_version,
85
- environment_id=environment_id,
86
- environment_slug=environment_slug,
87
- environment_version=environment_version,
88
- )
50
+ parameters = context.parameters
89
51
 
90
- if schema:
91
- return schema(**parameters)
92
-
93
- return parameters
52
+ if not parameters:
53
+ return None
94
54
 
95
- @staticmethod
96
- async def aget_from_route(
97
- schema: Optional[Type[T]] = None,
98
- ) -> Union[Dict[str, Any], T]:
99
- """
100
- Asynchronously retrieves the configuration from the route context and returns a config object.
55
+ if not schema:
56
+ return parameters
101
57
 
102
- This method checks the route context for configuration information and returns
103
- an instance of the specified schema based on the available context data.
104
-
105
- Args:
106
- schema (Type[T]): A Pydantic model class that defines the structure of the configuration.
107
-
108
- Returns:
109
- T: An instance of the specified schema populated with the configuration data.
110
-
111
- Raises:
112
- ValueError: If conflicting configuration sources are provided or if no valid
113
- configuration source is found in the context.
114
-
115
- Note:
116
- The method prioritizes the inputs in the following way:
117
- 1. 'config' (i.e. when called explicitly from the playground)
118
- 2. 'environment'
119
- 3. 'variant'
120
- Only one of these should be provided.
121
- """
122
-
123
- context = routing_context.get()
124
-
125
- parameters = None
126
-
127
- if "config" in context and context["config"]:
128
- parameters = context["config"]
129
-
130
- else:
131
- app_id: Optional[str] = None
132
- app_slug: Optional[str] = None
133
- variant_id: Optional[str] = None
134
- variant_slug: Optional[str] = None
135
- variant_version: Optional[int] = None
136
- environment_id: Optional[str] = None
137
- environment_slug: Optional[str] = None
138
- environment_version: Optional[int] = None
139
-
140
- if "application" in context:
141
- app_id = context["application"].get("id")
142
- app_slug = context["application"].get("slug")
143
-
144
- if "variant" in context:
145
- variant_id = context["variant"].get("id")
146
- variant_slug = context["variant"].get("slug")
147
- variant_version = context["variant"].get("version")
148
-
149
- if "environment" in context:
150
- environment_id = context["environment"].get("id")
151
- environment_slug = context["environment"].get("slug")
152
- environment_version = context["environment"].get("version")
153
-
154
- parameters = await ConfigManager.async_get_from_registry(
155
- app_id=app_id,
156
- app_slug=app_slug,
157
- variant_id=variant_id,
158
- variant_slug=variant_slug,
159
- variant_version=variant_version,
160
- environment_id=environment_id,
161
- environment_slug=environment_slug,
162
- environment_version=environment_version,
163
- )
164
-
165
- if schema:
166
- return schema(**parameters)
167
-
168
- return parameters
58
+ return schema(**parameters)
169
59
 
170
60
  @staticmethod
171
61
  def get_from_registry(
@@ -0,0 +1,38 @@
1
+ from typing import Optional, Dict, Any
2
+
3
+ from agenta.sdk.context.routing import routing_context
4
+
5
+ from agenta.sdk.assets import model_to_provider_mapping
6
+
7
+
8
+ class SecretsManager:
9
+ @staticmethod
10
+ def get_from_route() -> Optional[Dict[str, Any]]:
11
+ context = routing_context.get()
12
+
13
+ secrets = context.secrets
14
+
15
+ if not secrets:
16
+ return None
17
+
18
+ return secrets
19
+
20
+ @staticmethod
21
+ def get_api_key_for_model(model: str) -> str:
22
+ secrets = SecretsManager.get_from_route()
23
+
24
+ if not secrets:
25
+ return None
26
+
27
+ provider = model_to_provider_mapping.get(model)
28
+
29
+ if not provider:
30
+ return None
31
+
32
+ provider = provider.lower().replace(" ", "")
33
+
34
+ for secret in secrets:
35
+ if secret["data"]["provider"] == provider:
36
+ return secret["data"]["key"]
37
+
38
+ return None
@@ -1,90 +1,116 @@
1
1
  from typing import Callable, Optional
2
- from os import environ
3
- from uuid import UUID
2
+
3
+ from os import getenv
4
4
  from json import dumps
5
- from traceback import format_exc
6
5
 
7
6
  import httpx
8
7
  from starlette.middleware.base import BaseHTTPMiddleware
9
- from fastapi import FastAPI, Request, Response
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import JSONResponse
10
10
 
11
- from agenta.sdk.utils.logging import log
12
- from agenta.sdk.middleware.cache import TTLLRUCache
11
+ from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL
12
+ from agenta.sdk.utils.constants import TRUTHY
13
+ from agenta.sdk.utils.exceptions import display_exception
13
14
 
14
- AGENTA_SDK_AUTH_CACHE_CAPACITY = environ.get(
15
- "AGENTA_SDK_AUTH_CACHE_CAPACITY",
16
- 512,
17
- )
15
+ import agenta as ag
18
16
 
19
- AGENTA_SDK_AUTH_CACHE_TTL = environ.get(
20
- "AGENTA_SDK_AUTH_CACHE_TTL",
21
- 15 * 60, # 15 minutes
22
- )
23
17
 
24
- AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in (
25
- "true",
26
- "1",
27
- "t",
18
+ _SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "false").lower() in TRUTHY
19
+ _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
20
+ _UNAUTHORIZED_ALLOWED = (
21
+ getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in TRUTHY
28
22
  )
23
+ _ALWAYS_ALLOW_LIST = ["/health"]
29
24
 
30
- AGENTA_SDK_AUTH_CACHE = False
25
+ _cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL)
31
26
 
32
- AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str(
33
- environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False)
34
- ).lower() in ("true", "1", "t")
35
27
 
28
+ class DenyResponse(JSONResponse):
29
+ def __init__(
30
+ self,
31
+ status_code: int = 401,
32
+ detail: str = "Unauthorized",
33
+ ) -> None:
34
+ super().__init__(
35
+ status_code=status_code,
36
+ content={"detail": detail},
37
+ )
36
38
 
37
- class Deny(Response):
38
- def __init__(self) -> None:
39
- super().__init__(status_code=401, content="Unauthorized")
40
39
 
40
+ class DenyException(Exception):
41
+ def __init__(
42
+ self,
43
+ status_code: int = 401,
44
+ content: str = "Unauthorized",
45
+ ) -> None:
46
+ super().__init__()
41
47
 
42
- cache = TTLLRUCache(
43
- capacity=AGENTA_SDK_AUTH_CACHE_CAPACITY,
44
- ttl=AGENTA_SDK_AUTH_CACHE_TTL,
45
- )
48
+ self.status_code = status_code
49
+ self.content = content
46
50
 
47
51
 
48
- class AuthorizationMiddleware(BaseHTTPMiddleware):
49
- def __init__(
50
- self,
51
- app: FastAPI,
52
- host: str,
53
- resource_id: UUID,
54
- resource_type: str,
55
- ):
52
+ class AuthMiddleware(BaseHTTPMiddleware):
53
+ def __init__(self, app: FastAPI):
56
54
  super().__init__(app)
57
55
 
58
- self.host = host
59
- self.resource_id = resource_id
60
- self.resource_type = resource_type
56
+ self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
57
+ self.resource_id = (
58
+ ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
59
+ if not _SHARED_SERVICE
60
+ else None
61
+ )
62
+
63
+ async def dispatch(self, request: Request, call_next: Callable):
64
+ try:
65
+ if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST:
66
+ request.state.auth = {}
67
+
68
+ else:
69
+ credentials = await self._get_credentials(request)
70
+
71
+ request.state.auth = {"credentials": credentials}
61
72
 
62
- async def dispatch(
63
- self,
64
- request: Request,
65
- call_next: Callable,
66
- ):
67
- if AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED:
68
73
  return await call_next(request)
69
74
 
70
- try:
71
- authorization = (
72
- request.headers.get("Authorization")
73
- or request.headers.get("authorization")
74
- or None
75
+ except DenyException as deny:
76
+ display_exception("Auth Middleware Exception")
77
+
78
+ return DenyResponse(
79
+ status_code=deny.status_code,
80
+ detail=deny.content,
75
81
  )
76
82
 
83
+ except: # pylint: disable=bare-except
84
+ display_exception("Auth Middleware Exception")
85
+
86
+ return DenyResponse(
87
+ status_code=500,
88
+ detail="Auth: Unexpected Error.",
89
+ )
90
+
91
+ async def _get_credentials(self, request: Request) -> Optional[str]:
92
+ try:
93
+ authorization = request.headers.get("authorization", None)
94
+
77
95
  headers = {"Authorization": authorization} if authorization else None
78
96
 
79
- cookies = {"sAccessToken": request.cookies.get("sAccessToken")}
97
+ access_token = request.cookies.get("sAccessToken", None)
98
+
99
+ cookies = {"sAccessToken": access_token} if access_token else None
100
+
101
+ baggage = request.state.otel.get("baggage") if request.state.otel else {}
102
+
103
+ project_id = (
104
+ # CLEANEST
105
+ baggage.get("project_id")
106
+ # ALTERNATIVE
107
+ or request.query_params.get("project_id")
108
+ )
80
109
 
81
- params = {
82
- "action": "run_service",
83
- "resource_type": self.resource_type,
84
- "resource_id": self.resource_id,
85
- }
110
+ params = {"action": "run_service", "resource_type": "service"}
86
111
 
87
- project_id = request.query_params.get("project_id")
112
+ if self.resource_id:
113
+ params["resource_id"] = self.resource_id
88
114
 
89
115
  if project_id:
90
116
  params["project_id"] = project_id
@@ -98,48 +124,57 @@ class AuthorizationMiddleware(BaseHTTPMiddleware):
98
124
  sort_keys=True,
99
125
  )
100
126
 
101
- policy = None
102
- if AGENTA_SDK_AUTH_CACHE:
103
- policy = cache.get(_hash)
104
-
105
- if not policy:
106
- async with httpx.AsyncClient() as client:
107
- response = await client.get(
108
- f"{self.host}/api/permissions/verify",
109
- headers=headers,
110
- cookies=cookies,
111
- params=params,
112
- )
127
+ if _CACHE_ENABLED:
128
+ credentials = _cache.get(_hash)
129
+
130
+ if credentials:
131
+ return credentials
113
132
 
114
- if response.status_code != 200:
115
- cache.put(_hash, {"effect": "deny"})
116
- return Deny()
133
+ async with httpx.AsyncClient() as client:
134
+ response = await client.get(
135
+ f"{self.host}/api/permissions/verify",
136
+ headers=headers,
137
+ cookies=cookies,
138
+ params=params,
139
+ )
117
140
 
118
- auth = response.json()
141
+ if response.status_code == 401:
142
+ raise DenyException(
143
+ status_code=401,
144
+ content="Invalid credentials",
145
+ )
146
+ elif response.status_code == 403:
147
+ raise DenyException(
148
+ status_code=403,
149
+ content="Service execution not allowed.",
150
+ )
151
+ elif response.status_code != 200:
152
+ raise DenyException(
153
+ status_code=400,
154
+ content="Auth: Unexpected Error.",
155
+ )
119
156
 
120
- if auth.get("effect") != "allow":
121
- cache.put(_hash, {"effect": "deny"})
122
- return Deny()
157
+ auth = response.json()
123
158
 
124
- policy = {
125
- "effect": "allow",
126
- "credentials": auth.get("credentials"),
127
- }
159
+ if auth.get("effect") != "allow":
160
+ raise DenyException(
161
+ status_code=403,
162
+ content="Service execution not allowed.",
163
+ )
128
164
 
129
- cache.put(_hash, policy)
165
+ credentials = auth.get("credentials")
130
166
 
131
- if not policy or policy.get("effect") == "deny":
132
- return Deny()
167
+ _cache.put(_hash, credentials)
133
168
 
134
- request.state.credentials = policy.get("credentials")
169
+ return credentials
135
170
 
136
- return await call_next(request)
171
+ except DenyException as deny:
172
+ raise deny
137
173
 
138
- except: # pylint: disable=bare-except
139
- log.warning("------------------------------------------------------")
140
- log.warning("Agenta SDK - handling auth middleware exception below:")
141
- log.warning("------------------------------------------------------")
142
- log.warning(format_exc().strip("\n"))
143
- log.warning("------------------------------------------------------")
174
+ except Exception as exc: # pylint: disable=bare-except
175
+ display_exception("Auth Middleware Exception (suppressed)")
144
176
 
145
- return Deny()
177
+ raise DenyException(
178
+ status_code=500,
179
+ content="Auth: Unexpected Error.",
180
+ ) from exc
@@ -1,6 +1,10 @@
1
+ from os import getenv
1
2
  from time import time
2
3
  from collections import OrderedDict
3
4
 
5
+ CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512"))
6
+ CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes
7
+
4
8
 
5
9
  class TTLLRUCache:
6
10
  def __init__(self, capacity: int, ttl: int):