agenta 0.30.0a1__py3-none-any.whl → 0.30.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agenta might be problematic. Click here for more details.
- agenta/__init__.py +1 -0
- agenta/client/backend/__init__.py +32 -3
- agenta/client/backend/access_control/__init__.py +1 -0
- agenta/client/backend/access_control/client.py +167 -0
- agenta/client/backend/apps/client.py +70 -10
- agenta/client/backend/client.py +61 -45
- agenta/client/backend/configs/client.py +6 -0
- agenta/client/backend/containers/client.py +6 -0
- agenta/client/backend/core/file.py +13 -8
- agenta/client/backend/environments/client.py +6 -0
- agenta/client/backend/evaluations/client.py +14 -1
- agenta/client/backend/evaluators/client.py +24 -0
- agenta/client/backend/observability/client.py +22 -16
- agenta/client/backend/observability_v_1/__init__.py +2 -2
- agenta/client/backend/observability_v_1/client.py +203 -0
- agenta/client/backend/observability_v_1/types/__init__.py +2 -1
- agenta/client/backend/observability_v_1/types/format.py +1 -1
- agenta/client/backend/observability_v_1/types/query_analytics_response.py +7 -0
- agenta/client/backend/scopes/__init__.py +1 -0
- agenta/client/backend/scopes/client.py +114 -0
- agenta/client/backend/testsets/client.py +305 -121
- agenta/client/backend/types/__init__.py +24 -2
- agenta/client/backend/types/analytics_response.py +24 -0
- agenta/client/backend/types/app.py +2 -1
- agenta/client/backend/types/body_import_testset.py +0 -1
- agenta/client/backend/types/bucket_dto.py +26 -0
- agenta/client/backend/types/header_dto.py +22 -0
- agenta/client/backend/types/legacy_analytics_response.py +29 -0
- agenta/client/backend/types/legacy_data_point.py +27 -0
- agenta/client/backend/types/metrics_dto.py +24 -0
- agenta/client/backend/types/permission.py +1 -0
- agenta/client/backend/types/projects_response.py +28 -0
- agenta/client/backend/types/provider_key_dto.py +23 -0
- agenta/client/backend/types/provider_kind.py +21 -0
- agenta/client/backend/types/secret_dto.py +24 -0
- agenta/client/backend/types/secret_kind.py +5 -0
- agenta/client/backend/types/secret_response_dto.py +27 -0
- agenta/client/backend/variants/client.py +66 -0
- agenta/client/backend/vault/__init__.py +1 -0
- agenta/client/backend/vault/client.py +685 -0
- agenta/client/client.py +1 -1
- agenta/sdk/__init__.py +1 -0
- agenta/sdk/agenta_init.py +47 -118
- agenta/sdk/assets.py +57 -46
- agenta/sdk/context/exporting.py +25 -0
- agenta/sdk/context/routing.py +12 -12
- agenta/sdk/context/tracing.py +26 -1
- agenta/sdk/decorators/routing.py +272 -267
- agenta/sdk/decorators/tracing.py +53 -31
- agenta/sdk/managers/config.py +8 -118
- agenta/sdk/managers/secrets.py +38 -0
- agenta/sdk/middleware/auth.py +128 -93
- agenta/sdk/middleware/cache.py +4 -0
- agenta/sdk/middleware/config.py +254 -0
- agenta/sdk/middleware/cors.py +27 -0
- agenta/sdk/middleware/otel.py +40 -0
- agenta/sdk/middleware/vault.py +158 -0
- agenta/sdk/tracing/exporters.py +40 -2
- agenta/sdk/tracing/inline.py +2 -2
- agenta/sdk/tracing/processors.py +11 -3
- agenta/sdk/tracing/tracing.py +14 -12
- agenta/sdk/utils/constants.py +1 -0
- agenta/sdk/utils/exceptions.py +20 -19
- agenta/sdk/utils/globals.py +4 -8
- agenta/sdk/utils/timing.py +58 -0
- {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/METADATA +3 -2
- {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/RECORD +69 -44
- {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/WHEEL +1 -1
- agenta/client/backend/types/lm_providers_enum.py +0 -21
- agenta/sdk/tracing/context.py +0 -24
- {agenta-0.30.0a1.dist-info → agenta-0.30.0a3.dist-info}/entry_points.txt +0 -0
agenta/sdk/decorators/tracing.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
|
1
1
|
from typing import Callable, Optional, Any, Dict, List, Union
|
|
2
|
+
|
|
2
3
|
from functools import wraps
|
|
3
4
|
from itertools import chain
|
|
4
5
|
from inspect import iscoroutinefunction, getfullargspec
|
|
5
6
|
|
|
7
|
+
from opentelemetry import baggage as baggage
|
|
8
|
+
from opentelemetry.context import attach, detach
|
|
9
|
+
|
|
6
10
|
from agenta.sdk.utils.exceptions import suppress
|
|
7
11
|
from agenta.sdk.context.tracing import tracing_context
|
|
8
12
|
from agenta.sdk.tracing.conventions import parse_span_kind
|
|
@@ -39,10 +43,12 @@ class instrument: # pylint: disable=invalid-name
|
|
|
39
43
|
is_coroutine_function = iscoroutinefunction(func)
|
|
40
44
|
|
|
41
45
|
@wraps(func)
|
|
42
|
-
async def
|
|
43
|
-
async def
|
|
46
|
+
async def awrapper(*args, **kwargs):
|
|
47
|
+
async def aauto_instrumented(*args, **kwargs):
|
|
44
48
|
self._parse_type_and_kind()
|
|
45
49
|
|
|
50
|
+
token = self._attach_baggage()
|
|
51
|
+
|
|
46
52
|
with ag.tracer.start_as_current_span(func.__name__, kind=self.kind):
|
|
47
53
|
self._pre_instrument(func, *args, **kwargs)
|
|
48
54
|
|
|
@@ -52,13 +58,17 @@ class instrument: # pylint: disable=invalid-name
|
|
|
52
58
|
|
|
53
59
|
return result
|
|
54
60
|
|
|
55
|
-
|
|
61
|
+
self._detach_baggage(token)
|
|
62
|
+
|
|
63
|
+
return await aauto_instrumented(*args, **kwargs)
|
|
56
64
|
|
|
57
65
|
@wraps(func)
|
|
58
|
-
def
|
|
59
|
-
def
|
|
66
|
+
def wrapper(*args, **kwargs):
|
|
67
|
+
def auto_instrumented(*args, **kwargs):
|
|
60
68
|
self._parse_type_and_kind()
|
|
61
69
|
|
|
70
|
+
token = self._attach_baggage()
|
|
71
|
+
|
|
62
72
|
with ag.tracer.start_as_current_span(func.__name__, kind=self.kind):
|
|
63
73
|
self._pre_instrument(func, *args, **kwargs)
|
|
64
74
|
|
|
@@ -68,9 +78,11 @@ class instrument: # pylint: disable=invalid-name
|
|
|
68
78
|
|
|
69
79
|
return result
|
|
70
80
|
|
|
71
|
-
|
|
81
|
+
self._detach_baggage(token)
|
|
82
|
+
|
|
83
|
+
return auto_instrumented(*args, **kwargs)
|
|
72
84
|
|
|
73
|
-
return
|
|
85
|
+
return awrapper if is_coroutine_function else wrapper
|
|
74
86
|
|
|
75
87
|
def _parse_type_and_kind(self):
|
|
76
88
|
if not ag.tracing.get_current_span().is_recording():
|
|
@@ -78,6 +90,25 @@ class instrument: # pylint: disable=invalid-name
|
|
|
78
90
|
|
|
79
91
|
self.kind = parse_span_kind(self.type)
|
|
80
92
|
|
|
93
|
+
def _attach_baggage(self):
    """Expose the tracing context's references as OpenTelemetry baggage.

    Reads ``references`` from the current ``tracing_context`` and, for every
    ``key -> value`` pair, adds a baggage entry named ``ag.refs.<key>``.

    Returns:
        The token returned by ``attach`` (to be handed to
        ``_detach_baggage``), or ``None`` when there are no references and
        nothing was attached.
    """
    context = tracing_context.get()

    references = context.references

    if not references:
        return None

    # Accumulate every entry into ONE context object. Passing `context=None`
    # on the first iteration makes `set_baggage` start from the current
    # context; later iterations extend the context built so far.
    baggage_context = None
    for key, value in references.items():
        baggage_context = baggage.set_baggage(
            f"ag.refs.{key}",
            value,
            context=baggage_context,
        )

    # Attach exactly once: a single token then restores the context that was
    # current before this call. Attaching once per entry while keeping only
    # the last token would leave the earlier attachments in place after
    # detaching.
    return attach(baggage_context)
|
|
104
|
+
|
|
105
|
+
def _detach_baggage(self, token):
    """Undo a previous ``_attach_baggage`` call.

    Args:
        token: Value returned by ``_attach_baggage``. A falsy value
            (e.g. ``None``) means nothing was attached, so nothing is done.
    """
    if not token:
        return

    detach(token)
|
|
111
|
+
|
|
81
112
|
def _pre_instrument(
|
|
82
113
|
self,
|
|
83
114
|
func,
|
|
@@ -86,29 +117,21 @@ class instrument: # pylint: disable=invalid-name
|
|
|
86
117
|
):
|
|
87
118
|
span = ag.tracing.get_current_span()
|
|
88
119
|
|
|
120
|
+
context = tracing_context.get()
|
|
121
|
+
|
|
89
122
|
with suppress():
|
|
123
|
+
trace_id = span.context.trace_id
|
|
124
|
+
|
|
125
|
+
ag.tracing.credentials[trace_id] = context.credentials
|
|
126
|
+
|
|
90
127
|
span.set_attributes(
|
|
91
128
|
attributes={"node": self.type},
|
|
92
129
|
namespace="type",
|
|
93
130
|
)
|
|
94
131
|
|
|
95
132
|
if span.parent is None:
|
|
96
|
-
rctx = tracing_context.get()
|
|
97
|
-
|
|
98
|
-
span.set_attributes(
|
|
99
|
-
attributes={"configuration": rctx.get("config", {})},
|
|
100
|
-
namespace="meta",
|
|
101
|
-
)
|
|
102
|
-
span.set_attributes(
|
|
103
|
-
attributes={"environment": rctx.get("environment", {})},
|
|
104
|
-
namespace="meta",
|
|
105
|
-
)
|
|
106
133
|
span.set_attributes(
|
|
107
|
-
attributes={"
|
|
108
|
-
namespace="meta",
|
|
109
|
-
)
|
|
110
|
-
span.set_attributes(
|
|
111
|
-
attributes={"variant": rctx.get("variant", {})},
|
|
134
|
+
attributes={"configuration": context.parameters or {}},
|
|
112
135
|
namespace="meta",
|
|
113
136
|
)
|
|
114
137
|
|
|
@@ -118,6 +141,7 @@ class instrument: # pylint: disable=invalid-name
|
|
|
118
141
|
io=self._parse(func, *args, **kwargs),
|
|
119
142
|
ignore=self.ignore_inputs,
|
|
120
143
|
)
|
|
144
|
+
|
|
121
145
|
span.set_attributes(
|
|
122
146
|
attributes={"inputs": _inputs},
|
|
123
147
|
namespace="data",
|
|
@@ -161,6 +185,7 @@ class instrument: # pylint: disable=invalid-name
|
|
|
161
185
|
io=self._patch(result),
|
|
162
186
|
ignore=self.ignore_outputs,
|
|
163
187
|
)
|
|
188
|
+
|
|
164
189
|
span.set_attributes(
|
|
165
190
|
attributes={"outputs": _outputs},
|
|
166
191
|
namespace="data",
|
|
@@ -171,15 +196,12 @@ class instrument: # pylint: disable=invalid-name
|
|
|
171
196
|
|
|
172
197
|
with suppress():
|
|
173
198
|
if hasattr(span, "parent") and span.parent is None:
|
|
174
|
-
tracing_context.
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
}
|
|
181
|
-
}
|
|
182
|
-
)
|
|
199
|
+
context = tracing_context.get()
|
|
200
|
+
context.link = {
|
|
201
|
+
"tree_id": span.get_span_context().trace_id,
|
|
202
|
+
"node_id": span.get_span_context().span_id,
|
|
203
|
+
}
|
|
204
|
+
tracing_context.set(context)
|
|
183
205
|
|
|
184
206
|
def _parse(
|
|
185
207
|
self,
|
agenta/sdk/managers/config.py
CHANGED
|
@@ -7,7 +7,7 @@ import yaml
|
|
|
7
7
|
from pydantic import BaseModel
|
|
8
8
|
|
|
9
9
|
from agenta.sdk.managers.shared import SharedManager
|
|
10
|
-
from agenta.sdk.
|
|
10
|
+
from agenta.sdk.context.routing import routing_context
|
|
11
11
|
|
|
12
12
|
T = TypeVar("T", bound=BaseModel)
|
|
13
13
|
|
|
@@ -20,7 +20,7 @@ class ConfigManager:
|
|
|
20
20
|
@staticmethod
|
|
21
21
|
def get_from_route(
|
|
22
22
|
schema: Optional[Type[T]] = None,
|
|
23
|
-
) -> Union[Dict[str, Any], T]:
|
|
23
|
+
) -> Optional[Union[Dict[str, Any], T]]:
|
|
24
24
|
"""
|
|
25
25
|
Retrieves the configuration from the route context and returns a config object.
|
|
26
26
|
|
|
@@ -47,125 +47,15 @@ class ConfigManager:
|
|
|
47
47
|
|
|
48
48
|
context = routing_context.get()
|
|
49
49
|
|
|
50
|
-
parameters =
|
|
51
|
-
|
|
52
|
-
if "config" in context and context["config"]:
|
|
53
|
-
parameters = context["config"]
|
|
54
|
-
|
|
55
|
-
else:
|
|
56
|
-
app_id: Optional[str] = None
|
|
57
|
-
app_slug: Optional[str] = None
|
|
58
|
-
variant_id: Optional[str] = None
|
|
59
|
-
variant_slug: Optional[str] = None
|
|
60
|
-
variant_version: Optional[int] = None
|
|
61
|
-
environment_id: Optional[str] = None
|
|
62
|
-
environment_slug: Optional[str] = None
|
|
63
|
-
environment_version: Optional[int] = None
|
|
64
|
-
|
|
65
|
-
if "application" in context:
|
|
66
|
-
app_id = context["application"].get("id")
|
|
67
|
-
app_slug = context["application"].get("slug")
|
|
68
|
-
|
|
69
|
-
if "variant" in context:
|
|
70
|
-
variant_id = context["variant"].get("id")
|
|
71
|
-
variant_slug = context["variant"].get("slug")
|
|
72
|
-
variant_version = context["variant"].get("version")
|
|
73
|
-
|
|
74
|
-
if "environment" in context:
|
|
75
|
-
environment_id = context["environment"].get("id")
|
|
76
|
-
environment_slug = context["environment"].get("slug")
|
|
77
|
-
environment_version = context["environment"].get("version")
|
|
78
|
-
|
|
79
|
-
parameters = ConfigManager.get_from_registry(
|
|
80
|
-
app_id=app_id,
|
|
81
|
-
app_slug=app_slug,
|
|
82
|
-
variant_id=variant_id,
|
|
83
|
-
variant_slug=variant_slug,
|
|
84
|
-
variant_version=variant_version,
|
|
85
|
-
environment_id=environment_id,
|
|
86
|
-
environment_slug=environment_slug,
|
|
87
|
-
environment_version=environment_version,
|
|
88
|
-
)
|
|
50
|
+
parameters = context.parameters
|
|
89
51
|
|
|
90
|
-
if
|
|
91
|
-
return
|
|
92
|
-
|
|
93
|
-
return parameters
|
|
52
|
+
if not parameters:
|
|
53
|
+
return None
|
|
94
54
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
schema: Optional[Type[T]] = None,
|
|
98
|
-
) -> Union[Dict[str, Any], T]:
|
|
99
|
-
"""
|
|
100
|
-
Asynchronously retrieves the configuration from the route context and returns a config object.
|
|
55
|
+
if not schema:
|
|
56
|
+
return parameters
|
|
101
57
|
|
|
102
|
-
|
|
103
|
-
an instance of the specified schema based on the available context data.
|
|
104
|
-
|
|
105
|
-
Args:
|
|
106
|
-
schema (Type[T]): A Pydantic model class that defines the structure of the configuration.
|
|
107
|
-
|
|
108
|
-
Returns:
|
|
109
|
-
T: An instance of the specified schema populated with the configuration data.
|
|
110
|
-
|
|
111
|
-
Raises:
|
|
112
|
-
ValueError: If conflicting configuration sources are provided or if no valid
|
|
113
|
-
configuration source is found in the context.
|
|
114
|
-
|
|
115
|
-
Note:
|
|
116
|
-
The method prioritizes the inputs in the following way:
|
|
117
|
-
1. 'config' (i.e. when called explicitly from the playground)
|
|
118
|
-
2. 'environment'
|
|
119
|
-
3. 'variant'
|
|
120
|
-
Only one of these should be provided.
|
|
121
|
-
"""
|
|
122
|
-
|
|
123
|
-
context = routing_context.get()
|
|
124
|
-
|
|
125
|
-
parameters = None
|
|
126
|
-
|
|
127
|
-
if "config" in context and context["config"]:
|
|
128
|
-
parameters = context["config"]
|
|
129
|
-
|
|
130
|
-
else:
|
|
131
|
-
app_id: Optional[str] = None
|
|
132
|
-
app_slug: Optional[str] = None
|
|
133
|
-
variant_id: Optional[str] = None
|
|
134
|
-
variant_slug: Optional[str] = None
|
|
135
|
-
variant_version: Optional[int] = None
|
|
136
|
-
environment_id: Optional[str] = None
|
|
137
|
-
environment_slug: Optional[str] = None
|
|
138
|
-
environment_version: Optional[int] = None
|
|
139
|
-
|
|
140
|
-
if "application" in context:
|
|
141
|
-
app_id = context["application"].get("id")
|
|
142
|
-
app_slug = context["application"].get("slug")
|
|
143
|
-
|
|
144
|
-
if "variant" in context:
|
|
145
|
-
variant_id = context["variant"].get("id")
|
|
146
|
-
variant_slug = context["variant"].get("slug")
|
|
147
|
-
variant_version = context["variant"].get("version")
|
|
148
|
-
|
|
149
|
-
if "environment" in context:
|
|
150
|
-
environment_id = context["environment"].get("id")
|
|
151
|
-
environment_slug = context["environment"].get("slug")
|
|
152
|
-
environment_version = context["environment"].get("version")
|
|
153
|
-
|
|
154
|
-
parameters = await ConfigManager.async_get_from_registry(
|
|
155
|
-
app_id=app_id,
|
|
156
|
-
app_slug=app_slug,
|
|
157
|
-
variant_id=variant_id,
|
|
158
|
-
variant_slug=variant_slug,
|
|
159
|
-
variant_version=variant_version,
|
|
160
|
-
environment_id=environment_id,
|
|
161
|
-
environment_slug=environment_slug,
|
|
162
|
-
environment_version=environment_version,
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
if schema:
|
|
166
|
-
return schema(**parameters)
|
|
167
|
-
|
|
168
|
-
return parameters
|
|
58
|
+
return schema(**parameters)
|
|
169
59
|
|
|
170
60
|
@staticmethod
|
|
171
61
|
def get_from_registry(
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.context.routing import routing_context
|
|
4
|
+
|
|
5
|
+
from agenta.sdk.assets import model_to_provider_mapping
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SecretsManager:
    """Helpers for reading the secrets carried by the current routing context."""

    @staticmethod
    def get_from_route() -> Optional[Dict[str, Any]]:
        """Return the secrets stored on the routing context.

        Returns:
            ``routing_context.get().secrets`` when it is truthy, otherwise
            ``None``.
        """
        context = routing_context.get()

        secrets = context.secrets

        if not secrets:
            return None

        return secrets

    @staticmethod
    def get_api_key_for_model(model: str) -> Optional[str]:
        """Look up the API key of the provider that serves ``model``.

        Args:
            model: Model name, used as the key into
                ``model_to_provider_mapping``.

        Returns:
            The ``key`` of the first secret whose ``data["provider"]`` equals
            the normalized provider name, or ``None`` when there are no
            secrets, the model is not in the mapping, or no secret matches.
        """
        secrets = SecretsManager.get_from_route()

        if not secrets:
            return None

        provider = model_to_provider_mapping.get(model)

        if not provider:
            return None

        # Normalize the mapped provider name (lower-case, spaces removed)
        # before comparing it with the provider stored in each secret.
        provider = provider.lower().replace(" ", "")

        # NOTE(review): assumes `secrets` iterates over mappings shaped like
        # {"data": {"provider": ..., "key": ...}} -- confirm against whatever
        # populates `routing_context.secrets`. Entries that do not have that
        # shape are skipped instead of raising.
        for secret in secrets:
            data = secret.get("data") if isinstance(secret, dict) else None

            if isinstance(data, dict) and data.get("provider") == provider:
                return data.get("key")

        return None
|
agenta/sdk/middleware/auth.py
CHANGED
|
@@ -1,90 +1,116 @@
|
|
|
1
1
|
from typing import Callable, Optional
|
|
2
|
-
|
|
3
|
-
from
|
|
2
|
+
|
|
3
|
+
from os import getenv
|
|
4
4
|
from json import dumps
|
|
5
|
-
from traceback import format_exc
|
|
6
5
|
|
|
7
6
|
import httpx
|
|
8
7
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
9
|
-
from fastapi import FastAPI, Request
|
|
8
|
+
from fastapi import FastAPI, Request
|
|
9
|
+
from fastapi.responses import JSONResponse
|
|
10
10
|
|
|
11
|
-
from agenta.sdk.
|
|
12
|
-
from agenta.sdk.
|
|
11
|
+
from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL
|
|
12
|
+
from agenta.sdk.utils.constants import TRUTHY
|
|
13
|
+
from agenta.sdk.utils.exceptions import display_exception
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
"AGENTA_SDK_AUTH_CACHE_CAPACITY",
|
|
16
|
-
512,
|
|
17
|
-
)
|
|
15
|
+
import agenta as ag
|
|
18
16
|
|
|
19
|
-
AGENTA_SDK_AUTH_CACHE_TTL = environ.get(
|
|
20
|
-
"AGENTA_SDK_AUTH_CACHE_TTL",
|
|
21
|
-
15 * 60, # 15 minutes
|
|
22
|
-
)
|
|
23
17
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
"
|
|
18
|
+
_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "false").lower() in TRUTHY
|
|
19
|
+
_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
|
|
20
|
+
_UNAUTHORIZED_ALLOWED = (
|
|
21
|
+
getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in TRUTHY
|
|
28
22
|
)
|
|
23
|
+
_ALWAYS_ALLOW_LIST = ["/health"]
|
|
29
24
|
|
|
30
|
-
|
|
25
|
+
_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL)
|
|
31
26
|
|
|
32
|
-
AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str(
|
|
33
|
-
environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False)
|
|
34
|
-
).lower() in ("true", "1", "t")
|
|
35
27
|
|
|
28
|
+
class DenyResponse(JSONResponse):
    """JSON response used to reject a request.

    The body is ``{"detail": <detail>}`` and the HTTP status is
    ``status_code`` (401 / "Unauthorized" by default).
    """

    def __init__(
        self,
        status_code: int = 401,
        detail: str = "Unauthorized",
    ) -> None:
        payload = {"detail": detail}

        super().__init__(content=payload, status_code=status_code)
|
|
36
38
|
|
|
37
|
-
class Deny(Response):
|
|
38
|
-
def __init__(self) -> None:
|
|
39
|
-
super().__init__(status_code=401, content="Unauthorized")
|
|
40
39
|
|
|
40
|
+
class DenyException(Exception):
    """Raised to signal that a request must be rejected.

    Attributes:
        status_code: HTTP status to report (defaults to 401).
        content: Human-readable reason (defaults to "Unauthorized").
    """

    def __init__(
        self,
        status_code: int = 401,
        content: str = "Unauthorized",
    ) -> None:
        # Forward the reason to Exception so that str(exc), exc.args and
        # formatted tracebacks carry it instead of an empty message.
        super().__init__(content)

        self.status_code = status_code
        self.content = content
|
|
46
50
|
|
|
47
51
|
|
|
48
|
-
class
|
|
49
|
-
def __init__(
|
|
50
|
-
self,
|
|
51
|
-
app: FastAPI,
|
|
52
|
-
host: str,
|
|
53
|
-
resource_id: UUID,
|
|
54
|
-
resource_type: str,
|
|
55
|
-
):
|
|
52
|
+
class AuthMiddleware(BaseHTTPMiddleware):
|
|
53
|
+
def __init__(self, app: FastAPI):
|
|
56
54
|
super().__init__(app)
|
|
57
55
|
|
|
58
|
-
self.host = host
|
|
59
|
-
self.resource_id =
|
|
60
|
-
|
|
56
|
+
self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
|
|
57
|
+
self.resource_id = (
|
|
58
|
+
ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
|
|
59
|
+
if not _SHARED_SERVICE
|
|
60
|
+
else None
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
async def dispatch(self, request: Request, call_next: Callable):
|
|
64
|
+
try:
|
|
65
|
+
if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST:
|
|
66
|
+
request.state.auth = {}
|
|
67
|
+
|
|
68
|
+
else:
|
|
69
|
+
credentials = await self._get_credentials(request)
|
|
70
|
+
|
|
71
|
+
request.state.auth = {"credentials": credentials}
|
|
61
72
|
|
|
62
|
-
async def dispatch(
|
|
63
|
-
self,
|
|
64
|
-
request: Request,
|
|
65
|
-
call_next: Callable,
|
|
66
|
-
):
|
|
67
|
-
if AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED:
|
|
68
73
|
return await call_next(request)
|
|
69
74
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
+
except DenyException as deny:
|
|
76
|
+
display_exception("Auth Middleware Exception")
|
|
77
|
+
|
|
78
|
+
return DenyResponse(
|
|
79
|
+
status_code=deny.status_code,
|
|
80
|
+
detail=deny.content,
|
|
75
81
|
)
|
|
76
82
|
|
|
83
|
+
except: # pylint: disable=bare-except
|
|
84
|
+
display_exception("Auth Middleware Exception")
|
|
85
|
+
|
|
86
|
+
return DenyResponse(
|
|
87
|
+
status_code=500,
|
|
88
|
+
detail="Auth: Unexpected Error.",
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
async def _get_credentials(self, request: Request) -> Optional[str]:
|
|
92
|
+
try:
|
|
93
|
+
authorization = request.headers.get("authorization", None)
|
|
94
|
+
|
|
77
95
|
headers = {"Authorization": authorization} if authorization else None
|
|
78
96
|
|
|
79
|
-
|
|
97
|
+
access_token = request.cookies.get("sAccessToken", None)
|
|
98
|
+
|
|
99
|
+
cookies = {"sAccessToken": access_token} if access_token else None
|
|
100
|
+
|
|
101
|
+
baggage = request.state.otel.get("baggage") if request.state.otel else {}
|
|
102
|
+
|
|
103
|
+
project_id = (
|
|
104
|
+
# CLEANEST
|
|
105
|
+
baggage.get("project_id")
|
|
106
|
+
# ALTERNATIVE
|
|
107
|
+
or request.query_params.get("project_id")
|
|
108
|
+
)
|
|
80
109
|
|
|
81
|
-
params = {
|
|
82
|
-
"action": "run_service",
|
|
83
|
-
"resource_type": self.resource_type,
|
|
84
|
-
"resource_id": self.resource_id,
|
|
85
|
-
}
|
|
110
|
+
params = {"action": "run_service", "resource_type": "service"}
|
|
86
111
|
|
|
87
|
-
|
|
112
|
+
if self.resource_id:
|
|
113
|
+
params["resource_id"] = self.resource_id
|
|
88
114
|
|
|
89
115
|
if project_id:
|
|
90
116
|
params["project_id"] = project_id
|
|
@@ -98,48 +124,57 @@ class AuthorizationMiddleware(BaseHTTPMiddleware):
|
|
|
98
124
|
sort_keys=True,
|
|
99
125
|
)
|
|
100
126
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
async with httpx.AsyncClient() as client:
|
|
107
|
-
response = await client.get(
|
|
108
|
-
f"{self.host}/api/permissions/verify",
|
|
109
|
-
headers=headers,
|
|
110
|
-
cookies=cookies,
|
|
111
|
-
params=params,
|
|
112
|
-
)
|
|
127
|
+
if _CACHE_ENABLED:
|
|
128
|
+
credentials = _cache.get(_hash)
|
|
129
|
+
|
|
130
|
+
if credentials:
|
|
131
|
+
return credentials
|
|
113
132
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
133
|
+
async with httpx.AsyncClient() as client:
|
|
134
|
+
response = await client.get(
|
|
135
|
+
f"{self.host}/api/permissions/verify",
|
|
136
|
+
headers=headers,
|
|
137
|
+
cookies=cookies,
|
|
138
|
+
params=params,
|
|
139
|
+
)
|
|
117
140
|
|
|
118
|
-
|
|
141
|
+
if response.status_code == 401:
|
|
142
|
+
raise DenyException(
|
|
143
|
+
status_code=401,
|
|
144
|
+
content="Invalid credentials",
|
|
145
|
+
)
|
|
146
|
+
elif response.status_code == 403:
|
|
147
|
+
raise DenyException(
|
|
148
|
+
status_code=403,
|
|
149
|
+
content="Service execution not allowed.",
|
|
150
|
+
)
|
|
151
|
+
elif response.status_code != 200:
|
|
152
|
+
raise DenyException(
|
|
153
|
+
status_code=400,
|
|
154
|
+
content="Auth: Unexpected Error.",
|
|
155
|
+
)
|
|
119
156
|
|
|
120
|
-
|
|
121
|
-
cache.put(_hash, {"effect": "deny"})
|
|
122
|
-
return Deny()
|
|
157
|
+
auth = response.json()
|
|
123
158
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
159
|
+
if auth.get("effect") != "allow":
|
|
160
|
+
raise DenyException(
|
|
161
|
+
status_code=403,
|
|
162
|
+
content="Service execution not allowed.",
|
|
163
|
+
)
|
|
128
164
|
|
|
129
|
-
|
|
165
|
+
credentials = auth.get("credentials")
|
|
130
166
|
|
|
131
|
-
|
|
132
|
-
return Deny()
|
|
167
|
+
_cache.put(_hash, credentials)
|
|
133
168
|
|
|
134
|
-
|
|
169
|
+
return credentials
|
|
135
170
|
|
|
136
|
-
|
|
171
|
+
except DenyException as deny:
|
|
172
|
+
raise deny
|
|
137
173
|
|
|
138
|
-
except: # pylint: disable=bare-except
|
|
139
|
-
|
|
140
|
-
log.warning("Agenta SDK - handling auth middleware exception below:")
|
|
141
|
-
log.warning("------------------------------------------------------")
|
|
142
|
-
log.warning(format_exc().strip("\n"))
|
|
143
|
-
log.warning("------------------------------------------------------")
|
|
174
|
+
except Exception as exc: # pylint: disable=bare-except
|
|
175
|
+
display_exception("Auth Middleware Exception (suppressed)")
|
|
144
176
|
|
|
145
|
-
|
|
177
|
+
raise DenyException(
|
|
178
|
+
status_code=500,
|
|
179
|
+
content="Auth: Unexpected Error.",
|
|
180
|
+
) from exc
|
agenta/sdk/middleware/cache.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
|
+
from os import getenv
|
|
1
2
|
from time import time
|
|
2
3
|
from collections import OrderedDict
|
|
3
4
|
|
|
5
|
+
CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512"))
|
|
6
|
+
CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes
|
|
7
|
+
|
|
4
8
|
|
|
5
9
|
class TTLLRUCache:
|
|
6
10
|
def __init__(self, capacity: int, ttl: int):
|