agenta 0.32.0__py3-none-any.whl → 0.32.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agenta might be problematic. Click here for more details.
- agenta/__init__.py +3 -1
- agenta/client/backend/client.py +22 -14
- agenta/client/backend/core/http_client.py +3 -3
- agenta/sdk/__init__.py +1 -1
- agenta/sdk/context/routing.py +1 -0
- agenta/sdk/decorators/routing.py +164 -476
- agenta/sdk/decorators/tracing.py +16 -4
- agenta/sdk/litellm/litellm.py +44 -8
- agenta/sdk/litellm/mockllm.py +27 -0
- agenta/sdk/litellm/mocks/__init__.py +32 -0
- agenta/sdk/managers/vault.py +16 -0
- agenta/sdk/middleware/auth.py +5 -1
- agenta/sdk/middleware/config.py +16 -7
- agenta/sdk/middleware/inline.py +38 -0
- agenta/sdk/middleware/mock.py +33 -0
- agenta/sdk/middleware/vault.py +6 -19
- agenta/sdk/tracing/exporters.py +0 -1
- agenta/sdk/tracing/inline.py +23 -29
- agenta/sdk/types.py +334 -4
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/METADATA +1 -1
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/RECORD +23 -18
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/WHEEL +0 -0
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/entry_points.txt +0 -0
agenta/sdk/decorators/tracing.py
CHANGED
|
@@ -165,15 +165,27 @@ class instrument: # pylint: disable=invalid-name
|
|
|
165
165
|
usage = {"total_tokens": usage}
|
|
166
166
|
|
|
167
167
|
span.set_attributes(
|
|
168
|
-
attributes={"total": cost},
|
|
168
|
+
attributes={"total": float(cost) if cost else None},
|
|
169
169
|
namespace="metrics.unit.costs",
|
|
170
170
|
)
|
|
171
171
|
span.set_attributes(
|
|
172
172
|
attributes=(
|
|
173
173
|
{
|
|
174
|
-
"prompt":
|
|
175
|
-
|
|
176
|
-
|
|
174
|
+
"prompt": (
|
|
175
|
+
float(usage.get("prompt_tokens"))
|
|
176
|
+
if usage.get("prompt_tokens", None)
|
|
177
|
+
else None
|
|
178
|
+
),
|
|
179
|
+
"completion": (
|
|
180
|
+
float(usage.get("completion_tokens"))
|
|
181
|
+
if usage.get("completion_tokens", None)
|
|
182
|
+
else None
|
|
183
|
+
),
|
|
184
|
+
"total": (
|
|
185
|
+
float(usage.get("total_tokens", None))
|
|
186
|
+
if usage.get("total_tokens", None)
|
|
187
|
+
else None
|
|
188
|
+
),
|
|
177
189
|
}
|
|
178
190
|
),
|
|
179
191
|
namespace="metrics.unit.tokens",
|
agenta/sdk/litellm/litellm.py
CHANGED
|
@@ -154,16 +154,34 @@ def litellm_handler():
|
|
|
154
154
|
pass
|
|
155
155
|
|
|
156
156
|
span.set_attributes(
|
|
157
|
-
attributes={
|
|
157
|
+
attributes={
|
|
158
|
+
"total": (
|
|
159
|
+
float(kwargs.get("response_cost"))
|
|
160
|
+
if kwargs.get("response_cost")
|
|
161
|
+
else None
|
|
162
|
+
)
|
|
163
|
+
},
|
|
158
164
|
namespace="metrics.unit.costs",
|
|
159
165
|
)
|
|
160
166
|
|
|
161
167
|
span.set_attributes(
|
|
162
168
|
attributes=(
|
|
163
169
|
{
|
|
164
|
-
"prompt":
|
|
165
|
-
|
|
166
|
-
|
|
170
|
+
"prompt": (
|
|
171
|
+
float(response_obj.usage.prompt_tokens)
|
|
172
|
+
if response_obj.usage.prompt_tokens
|
|
173
|
+
else None
|
|
174
|
+
),
|
|
175
|
+
"completion": (
|
|
176
|
+
float(response_obj.usage.completion_tokens)
|
|
177
|
+
if response_obj.usage.completion_tokens
|
|
178
|
+
else None
|
|
179
|
+
),
|
|
180
|
+
"total": (
|
|
181
|
+
float(response_obj.usage.total_tokens)
|
|
182
|
+
if response_obj.usage.total_tokens
|
|
183
|
+
else None
|
|
184
|
+
),
|
|
167
185
|
}
|
|
168
186
|
),
|
|
169
187
|
namespace="metrics.unit.tokens",
|
|
@@ -264,16 +282,34 @@ def litellm_handler():
|
|
|
264
282
|
pass
|
|
265
283
|
|
|
266
284
|
span.set_attributes(
|
|
267
|
-
attributes={
|
|
285
|
+
attributes={
|
|
286
|
+
"total": (
|
|
287
|
+
float(kwargs.get("response_cost"))
|
|
288
|
+
if kwargs.get("response_cost")
|
|
289
|
+
else None
|
|
290
|
+
)
|
|
291
|
+
},
|
|
268
292
|
namespace="metrics.unit.costs",
|
|
269
293
|
)
|
|
270
294
|
|
|
271
295
|
span.set_attributes(
|
|
272
296
|
attributes=(
|
|
273
297
|
{
|
|
274
|
-
"prompt":
|
|
275
|
-
|
|
276
|
-
|
|
298
|
+
"prompt": (
|
|
299
|
+
float(response_obj.usage.prompt_tokens)
|
|
300
|
+
if response_obj.usage.prompt_tokens
|
|
301
|
+
else None
|
|
302
|
+
),
|
|
303
|
+
"completion": (
|
|
304
|
+
float(response_obj.usage.completion_tokens)
|
|
305
|
+
if response_obj.usage.completion_tokens
|
|
306
|
+
else None
|
|
307
|
+
),
|
|
308
|
+
"total": (
|
|
309
|
+
float(response_obj.usage.total_tokens)
|
|
310
|
+
if response_obj.usage.total_tokens
|
|
311
|
+
else None
|
|
312
|
+
),
|
|
277
313
|
}
|
|
278
314
|
),
|
|
279
315
|
namespace="metrics.unit.tokens",
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Optional, Protocol, Any
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.litellm.mocks import MOCKS
|
|
4
|
+
from agenta.sdk.context.routing import routing_context
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LitellmProtocol(Protocol):
|
|
8
|
+
async def acompletion(self, *args: Any, **kwargs: Any) -> Any:
|
|
9
|
+
...
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
litellm: Optional[LitellmProtocol] = None # pylint: disable=invalid-name
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def acompletion(*args, **kwargs):
|
|
16
|
+
mock = routing_context.get().mock
|
|
17
|
+
|
|
18
|
+
if mock:
|
|
19
|
+
if mock not in MOCKS:
|
|
20
|
+
raise ValueError(f"Mock {mock} not found")
|
|
21
|
+
|
|
22
|
+
return MOCKS[mock](*args, **kwargs)
|
|
23
|
+
|
|
24
|
+
if not litellm:
|
|
25
|
+
raise ValueError("litellm not found")
|
|
26
|
+
|
|
27
|
+
return await litellm.acompletion(*args, **kwargs)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MockMessageModel(BaseModel):
|
|
7
|
+
content: str
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MockChoiceModel(BaseModel):
|
|
11
|
+
message: MockMessageModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MockResponseModel(BaseModel):
|
|
15
|
+
choices: list[MockChoiceModel]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def hello_mock_response(*args, **kwargs) -> MockResponseModel:
|
|
19
|
+
return MockResponseModel(
|
|
20
|
+
choices=[
|
|
21
|
+
MockChoiceModel(
|
|
22
|
+
message=MockMessageModel(
|
|
23
|
+
content="world",
|
|
24
|
+
)
|
|
25
|
+
)
|
|
26
|
+
],
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
MOCKS: dict[str, Callable[..., MockResponseModel]] = {
|
|
31
|
+
"hello": hello_mock_response,
|
|
32
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.context.routing import routing_context
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class VaultManager:
|
|
7
|
+
@staticmethod
|
|
8
|
+
def get_from_route() -> Optional[Dict[str, Any]]:
|
|
9
|
+
context = routing_context.get()
|
|
10
|
+
|
|
11
|
+
secrets = context.secrets
|
|
12
|
+
|
|
13
|
+
if not secrets:
|
|
14
|
+
return None
|
|
15
|
+
|
|
16
|
+
return secrets
|
agenta/sdk/middleware/auth.py
CHANGED
|
@@ -98,7 +98,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
98
98
|
|
|
99
99
|
cookies = {"sAccessToken": access_token} if access_token else None
|
|
100
100
|
|
|
101
|
-
baggage =
|
|
101
|
+
baggage = (
|
|
102
|
+
getattr(request.state.otel, "baggage")
|
|
103
|
+
if hasattr(request.state, "otel")
|
|
104
|
+
else {}
|
|
105
|
+
)
|
|
102
106
|
|
|
103
107
|
project_id = (
|
|
104
108
|
# CLEANEST
|
agenta/sdk/middleware/config.py
CHANGED
|
@@ -40,7 +40,7 @@ class ConfigMiddleware(BaseHTTPMiddleware):
|
|
|
40
40
|
request: Request,
|
|
41
41
|
call_next: Callable,
|
|
42
42
|
):
|
|
43
|
-
request.state.config = {}
|
|
43
|
+
request.state.config = {"parameters": None, "references": None}
|
|
44
44
|
|
|
45
45
|
with suppress():
|
|
46
46
|
parameters, references = await self._get_config(request)
|
|
@@ -116,13 +116,14 @@ class ConfigMiddleware(BaseHTTPMiddleware):
|
|
|
116
116
|
|
|
117
117
|
for ref_key in ["application_ref", "variant_ref", "environment_ref"]:
|
|
118
118
|
refs = config.get(ref_key)
|
|
119
|
-
|
|
119
|
+
if refs:
|
|
120
|
+
ref_prefix = ref_key.split("_", maxsplit=1)[0]
|
|
120
121
|
|
|
121
|
-
|
|
122
|
-
|
|
122
|
+
for ref_part_key in ["id", "slug", "version"]:
|
|
123
|
+
ref_part = refs.get(ref_part_key)
|
|
123
124
|
|
|
124
|
-
|
|
125
|
-
|
|
125
|
+
if ref_part:
|
|
126
|
+
references[ref_prefix + "." + ref_part_key] = str(ref_part)
|
|
126
127
|
|
|
127
128
|
_cache.put(_hash, {"parameters": parameters, "references": references})
|
|
128
129
|
|
|
@@ -159,12 +160,20 @@ class ConfigMiddleware(BaseHTTPMiddleware):
|
|
|
159
160
|
or body.get("app")
|
|
160
161
|
)
|
|
161
162
|
|
|
162
|
-
|
|
163
|
+
application_version = (
|
|
164
|
+
# CLEANEST
|
|
165
|
+
baggage.get("application_version")
|
|
166
|
+
# ALTERNATIVE
|
|
167
|
+
or request.query_params.get("application_version")
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
if not any([application_id, application_slug, application_version]):
|
|
163
171
|
return None
|
|
164
172
|
|
|
165
173
|
return Reference(
|
|
166
174
|
id=application_id,
|
|
167
175
|
slug=application_slug,
|
|
176
|
+
version=application_version,
|
|
168
177
|
)
|
|
169
178
|
|
|
170
179
|
async def _parse_variant_ref(
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
from fastapi import Request, FastAPI
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from agenta.sdk.utils.exceptions import suppress
|
|
8
|
+
|
|
9
|
+
from agenta.sdk.utils.constants import TRUTHY
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InlineMiddleware(BaseHTTPMiddleware):
|
|
13
|
+
def __init__(self, app: FastAPI):
|
|
14
|
+
super().__init__(app)
|
|
15
|
+
|
|
16
|
+
async def dispatch(
|
|
17
|
+
self,
|
|
18
|
+
request: Request,
|
|
19
|
+
call_next: Callable,
|
|
20
|
+
):
|
|
21
|
+
request.state.inline = False
|
|
22
|
+
|
|
23
|
+
with suppress():
|
|
24
|
+
baggage = request.state.otel.get("baggage") if request.state.otel else {}
|
|
25
|
+
|
|
26
|
+
inline = (
|
|
27
|
+
str(
|
|
28
|
+
# CLEANEST
|
|
29
|
+
baggage.get("inline")
|
|
30
|
+
# ALTERNATIVE
|
|
31
|
+
or request.query_params.get("inline")
|
|
32
|
+
)
|
|
33
|
+
in TRUTHY
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
request.state.inline = inline
|
|
37
|
+
|
|
38
|
+
return await call_next(request)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
from fastapi import Request, FastAPI
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from agenta.sdk.utils.exceptions import suppress
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MockMiddleware(BaseHTTPMiddleware):
|
|
11
|
+
def __init__(self, app: FastAPI):
|
|
12
|
+
super().__init__(app)
|
|
13
|
+
|
|
14
|
+
async def dispatch(
|
|
15
|
+
self,
|
|
16
|
+
request: Request,
|
|
17
|
+
call_next: Callable,
|
|
18
|
+
):
|
|
19
|
+
request.state.mock = None
|
|
20
|
+
|
|
21
|
+
with suppress():
|
|
22
|
+
baggage = request.state.otel.get("baggage") if request.state.otel else {}
|
|
23
|
+
|
|
24
|
+
mock = (
|
|
25
|
+
# CLEANEST
|
|
26
|
+
baggage.get("mock")
|
|
27
|
+
# ALTERNATIVE
|
|
28
|
+
or request.query_params.get("mock")
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
request.state.mock = mock
|
|
32
|
+
|
|
33
|
+
return await call_next(request)
|
agenta/sdk/middleware/vault.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from os import getenv
|
|
2
2
|
from json import dumps
|
|
3
|
-
from typing import Callable, Dict, Optional, List, Any
|
|
3
|
+
from typing import Callable, Dict, Optional, List, Any, get_args
|
|
4
4
|
|
|
5
5
|
import httpx
|
|
6
6
|
from fastapi import FastAPI, Request
|
|
@@ -18,25 +18,11 @@ from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL
|
|
|
18
18
|
import agenta as ag
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
# for a fixed set of string literals representing various provider names, alongside `typing.Any`.
|
|
23
|
-
PROVIDER_KINDS = []
|
|
21
|
+
_PROVIDER_KINDS = []
|
|
24
22
|
|
|
25
|
-
# Rationale behind the following:
|
|
26
|
-
# -------------------------------
|
|
27
|
-
# You cannot loop directly over the values in `typing.Literal` because:
|
|
28
|
-
# - `Literal` is not iterable.
|
|
29
|
-
# - `ProviderKind.__args__` includes `Literal` and `Any`, but the actual string values
|
|
30
|
-
# are nested within the `Literal`'s own `__args__` attribute.
|
|
31
|
-
|
|
32
|
-
# To solve this, we programmatically extract the values from `Literal` while retaining
|
|
33
|
-
# the structure of ProviderKind. This ensures:
|
|
34
|
-
# 1. We don't modify the original `ProviderKind` type definition.
|
|
35
|
-
# 2. We dynamically access the literal values for use at runtime when necessary.
|
|
36
23
|
for arg in ProviderKind.__args__: # type: ignore
|
|
37
24
|
if hasattr(arg, "__args__"):
|
|
38
|
-
|
|
39
|
-
|
|
25
|
+
_PROVIDER_KINDS.extend(arg.__args__)
|
|
40
26
|
|
|
41
27
|
_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
|
|
42
28
|
|
|
@@ -100,7 +86,7 @@ class VaultMiddleware(BaseHTTPMiddleware):
|
|
|
100
86
|
local_secrets: List[SecretDTO] = []
|
|
101
87
|
|
|
102
88
|
try:
|
|
103
|
-
for provider_kind in
|
|
89
|
+
for provider_kind in _PROVIDER_KINDS:
|
|
104
90
|
provider = provider_kind
|
|
105
91
|
key_name = f"{provider.upper()}_API_KEY"
|
|
106
92
|
key = getenv(key_name)
|
|
@@ -108,7 +94,8 @@ class VaultMiddleware(BaseHTTPMiddleware):
|
|
|
108
94
|
if not key:
|
|
109
95
|
continue
|
|
110
96
|
|
|
111
|
-
secret = SecretDTO(
|
|
97
|
+
secret = SecretDTO(
|
|
98
|
+
# kind=... # defaults to 'provider_kind'
|
|
112
99
|
data=ProviderKeyDTO(
|
|
113
100
|
provider=provider,
|
|
114
101
|
key=key,
|
agenta/sdk/tracing/exporters.py
CHANGED
agenta/sdk/tracing/inline.py
CHANGED
|
@@ -701,15 +701,6 @@ def _parse_from_semconv(
|
|
|
701
701
|
def _parse_from_links(
|
|
702
702
|
otel_span_dto: OTelSpanDTO,
|
|
703
703
|
) -> dict:
|
|
704
|
-
# TESTING
|
|
705
|
-
otel_span_dto.links = [
|
|
706
|
-
OTelLinkDTO(
|
|
707
|
-
context=otel_span_dto.context,
|
|
708
|
-
attributes={"ag.type.link": "testcase"},
|
|
709
|
-
)
|
|
710
|
-
]
|
|
711
|
-
# -------
|
|
712
|
-
|
|
713
704
|
# LINKS
|
|
714
705
|
links = None
|
|
715
706
|
otel_links = None
|
|
@@ -926,8 +917,9 @@ def parse_to_agenta_span_dto(
|
|
|
926
917
|
if span_dto.refs:
|
|
927
918
|
span_dto.refs = _unmarshal_attributes(span_dto.refs)
|
|
928
919
|
|
|
929
|
-
|
|
930
|
-
link
|
|
920
|
+
if isinstance(span_dto.links, list):
|
|
921
|
+
for link in span_dto.links:
|
|
922
|
+
link.tree_id = None
|
|
931
923
|
|
|
932
924
|
if span_dto.nodes:
|
|
933
925
|
for v in span_dto.nodes.values():
|
|
@@ -1030,6 +1022,24 @@ def _parse_readable_spans(
|
|
|
1030
1022
|
otel_span_dtos = list()
|
|
1031
1023
|
|
|
1032
1024
|
for span in spans:
|
|
1025
|
+
otel_events = [
|
|
1026
|
+
OTelEventDTO(
|
|
1027
|
+
name=event.name,
|
|
1028
|
+
timestamp=_timestamp_ns_to_datetime(event.timestamp),
|
|
1029
|
+
attributes=event.attributes,
|
|
1030
|
+
)
|
|
1031
|
+
for event in span.events
|
|
1032
|
+
]
|
|
1033
|
+
otel_links = [
|
|
1034
|
+
OTelLinkDTO(
|
|
1035
|
+
context=OTelContextDTO(
|
|
1036
|
+
trace_id=_int_to_hex(link.context.trace_id, 128),
|
|
1037
|
+
span_id=_int_to_hex(link.context.span_id, 64),
|
|
1038
|
+
),
|
|
1039
|
+
attributes=link.attributes,
|
|
1040
|
+
)
|
|
1041
|
+
for link in span.links
|
|
1042
|
+
]
|
|
1033
1043
|
otel_span_dto = OTelSpanDTO(
|
|
1034
1044
|
context=OTelContextDTO(
|
|
1035
1045
|
trace_id=_int_to_hex(span.get_span_context().trace_id, 128),
|
|
@@ -1045,14 +1055,7 @@ def _parse_readable_spans(
|
|
|
1045
1055
|
status_code=OTelStatusCode("STATUS_CODE_" + span.status.status_code.name),
|
|
1046
1056
|
status_message=span.status.description,
|
|
1047
1057
|
attributes=span.attributes,
|
|
1048
|
-
events=
|
|
1049
|
-
OTelEventDTO(
|
|
1050
|
-
name=event.name,
|
|
1051
|
-
timestamp=_timestamp_ns_to_datetime(event.timestamp),
|
|
1052
|
-
attributes=event.attributes,
|
|
1053
|
-
)
|
|
1054
|
-
for event in span.events
|
|
1055
|
-
],
|
|
1058
|
+
events=otel_events if len(otel_events) > 0 else None,
|
|
1056
1059
|
parent=(
|
|
1057
1060
|
OTelContextDTO(
|
|
1058
1061
|
trace_id=_int_to_hex(span.parent.trace_id, 128),
|
|
@@ -1061,16 +1064,7 @@ def _parse_readable_spans(
|
|
|
1061
1064
|
if span.parent
|
|
1062
1065
|
else None
|
|
1063
1066
|
),
|
|
1064
|
-
links=
|
|
1065
|
-
OTelLinkDTO(
|
|
1066
|
-
context=OTelContextDTO(
|
|
1067
|
-
trace_id=_int_to_hex(link.context.trace_id, 128),
|
|
1068
|
-
span_id=_int_to_hex(link.context.span_id, 64),
|
|
1069
|
-
),
|
|
1070
|
-
attributes=link.attributes,
|
|
1071
|
-
)
|
|
1072
|
-
for link in span.links
|
|
1073
|
-
],
|
|
1067
|
+
links=otel_links if len(otel_links) > 0 else None,
|
|
1074
1068
|
)
|
|
1075
1069
|
|
|
1076
1070
|
otel_span_dtos.append(otel_span_dto)
|