agenta 0.50.6__py3-none-any.whl → 0.51.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agenta might be problematic. Click here for more details.
- agenta/__init__.py +8 -5
- agenta/sdk/__init__.py +2 -1
- agenta/sdk/agenta_init.py +2 -2
- agenta/sdk/context/running.py +39 -0
- agenta/sdk/context/{routing.py → serving.py} +4 -4
- agenta/sdk/context/tracing.py +27 -1
- agenta/sdk/decorators/running.py +134 -0
- agenta/sdk/decorators/{routing.py → serving.py} +5 -5
- agenta/sdk/decorators/tracing.py +7 -1
- agenta/sdk/litellm/mockllm.py +55 -2
- agenta/sdk/managers/config.py +2 -2
- agenta/sdk/managers/secrets.py +2 -2
- agenta/sdk/managers/vault.py +2 -2
- agenta/sdk/middleware/adapt.py +253 -0
- agenta/sdk/middleware/auth.py +179 -3
- agenta/sdk/middleware/base.py +40 -0
- agenta/sdk/middleware/flags.py +40 -0
- agenta/sdk/tracing/exporters.py +9 -9
- agenta/sdk/tracing/inline.py +1 -0
- agenta/sdk/tracing/tracing.py +0 -8
- agenta/sdk/utils/cache.py +3 -3
- agenta/sdk/workflows/__init__.py +0 -0
- agenta/sdk/workflows/registry.py +32 -0
- agenta/sdk/workflows/types.py +470 -0
- agenta/sdk/workflows/utils.py +17 -0
- {agenta-0.50.6.dist-info → agenta-0.51.1.dist-info}/METADATA +4 -2
- {agenta-0.50.6.dist-info → agenta-0.51.1.dist-info}/RECORD +28 -20
- agenta/sdk/context/exporting.py +0 -25
- {agenta-0.50.6.dist-info → agenta-0.51.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
from inspect import signature
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
from agenta.sdk.utils.logging import get_module_logger
|
|
6
|
+
|
|
7
|
+
from agenta.sdk.middleware.base import (
|
|
8
|
+
WorkflowMiddleware,
|
|
9
|
+
middleware_as_decorator,
|
|
10
|
+
)
|
|
11
|
+
from agenta.sdk.workflows.types import (
|
|
12
|
+
WorkflowServiceRequest,
|
|
13
|
+
WorkflowServiceResponse,
|
|
14
|
+
WorkflowServiceData,
|
|
15
|
+
WorkflowRevision,
|
|
16
|
+
Status,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
log = get_module_logger(__name__)
|
|
20
|
+
|
|
21
|
+
# Dotted lookup paths used to resolve a handler parameter name when the
# request does not carry that key directly. The first two segments select the
# source ("request.data" or "revision.data"); a numeric segment is a
# positional index and "{}" marks a per-entry fan-out over a dict.
DEFAULT_INPUTS_MAPPINGS = {
    "request": "request",
    "revision": "revision",
    "parameters": "revision.data.parameters",
    "inputs": "request.data.inputs",
    # "outputs" and "trace_outputs" intentionally share the same path.
    "outputs": "request.data.traces.0.attributes.ag.data.outputs",
    "trace": "request.data.traces.0",
    "trace_outputs": "request.data.traces.0.attributes.ag.data.outputs",
    "traces": "request.data.traces",
    "traces_outputs": "request.data.traces.{}.attributes.ag.data.outputs",
}

# Parameter names a handler may declare and have resolved for it.
ALLOWED_INPUTS_KEYS = set(DEFAULT_INPUTS_MAPPINGS)

# Keys a handler result may populate in the response data.
ALLOWED_OUTPUTS_KEYS = {"outputs", "trace"}

DEFAULT_MAPPINGS = {}

# Version string stamped on every WorkflowServiceResponse built in this module.
CURRENT_VERSION = "2025.07.14"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@middleware_as_decorator
class AdaptMiddleware(WorkflowMiddleware):
    """Adapt a workflow request to the keyword arguments a handler declares.

    The handler's signature is inspected. Every parameter whose name is in
    ``ALLOWED_INPUTS_KEYS`` is resolved from the request/revision (directly,
    or through the dotted path in ``DEFAULT_INPUTS_MAPPINGS``); any other
    parameter is passed as ``None``. The handler result is wrapped in a
    ``WorkflowServiceResponse``.
    """

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: Callable,
    ) -> WorkflowServiceResponse:
        """Call ``handler`` with the inputs it asks for and wrap its result.

        Args:
            request: Incoming workflow request; ``request.data`` is dumped to
                a JSON-compatible dict for key lookups.
            revision: Workflow revision; supplies ``parameters``.
            handler: Awaitable callable whose parameter names select inputs.

        Returns:
            A response carrying the handler outputs (``trace`` is not
            populated yet).

        Raises:
            Exception: whatever the handler raises is logged and re-raised.
        """
        request_data_dict = request.data.model_dump(
            mode="json",
            exclude_none=True,
        )
        revision_data_dict = revision.data.model_dump(
            mode="json",
            exclude_none=True,
        )

        # Keys that can be served directly, without walking a mapping path.
        provided_request_keys = {"request", "revision", "parameters"} | set(
            request_data_dict
        )

        handler_signature = signature(handler)

        # Sorted so that resolution order (and therefore what is already
        # resolved when a failure occurs) is deterministic.
        requested_inputs_keys = sorted(handler_signature.parameters)

        kwargs = {}

        try:
            for requested_input_key in requested_inputs_keys:
                if requested_input_key not in ALLOWED_INPUTS_KEYS:
                    kwargs[requested_input_key] = None
                elif requested_input_key not in provided_request_keys:
                    kwargs[requested_input_key] = self._apply_request_mapping(
                        request=request_data_dict,
                        revision=revision_data_dict,
                        key=requested_input_key,
                    )
                elif requested_input_key == "parameters":
                    kwargs[requested_input_key] = revision.data.parameters or None
                elif requested_input_key == "request":
                    kwargs[requested_input_key] = request
                elif requested_input_key == "revision":
                    kwargs[requested_input_key] = revision
                else:
                    kwargs[requested_input_key] = request_data_dict[
                        requested_input_key
                    ]

        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: keep whatever was resolved so far, but leave a
            # trace of the failure instead of swallowing it silently.
            log.error("Failed to resolve workflow handler inputs", exc_info=True)

        # TODO: validate kwargs["inputs"] / kwargs["parameters"] against their
        # schemas once those schemas are available.

        try:
            # Pre-flight check only; the handler call below raises the
            # authoritative error when the arguments do not fit.
            handler_signature.bind(**kwargs)

        except TypeError:
            log.debug(
                "Resolved inputs do not match the handler signature",
                exc_info=True,
            )

        try:
            outputs = await handler(**kwargs)

            trace = None  # TODO: get trace

        except Exception:
            # Log with a message: a bare ``log.debug()`` would itself raise
            # and mask the handler's exception.
            log.debug("Workflow handler raised an exception", exc_info=True)
            raise

        # TODO: validate outputs against the outputs schema once available.

        return WorkflowServiceResponse(
            id=uuid4(),
            version=CURRENT_VERSION,
            data=WorkflowServiceData(
                outputs=outputs,
                trace=trace,
            ),
        )

    def _apply_request_mapping(
        self,
        request: dict,
        revision: dict,
        key: str,
    ):
        """Resolve ``key`` by walking its path from ``DEFAULT_INPUTS_MAPPINGS``.

        The first two path segments pick the source dict (``request.data`` or
        ``revision.data``; anything else starts from an empty dict). Remaining
        segments are applied in order:

        * a numeric segment indexes a list, or picks the n-th entry of a dict;
        * ``"[]"`` / ``"{}"`` switch to fan-out when the current value is a
          list / dict: every following segment is applied to each element and
          the container shape is kept;
        * any other segment is a dict key lookup (missing keys give ``None``).

        Args:
            request: JSON dump of the request data.
            revision: JSON dump of the revision data.
            key: Handler parameter name; must be a mapping key.

        Returns:
            The resolved value, or ``None`` when a key along the path is
            missing.

        Raises:
            KeyError: ``key`` has no mapping.
            IndexError: the mapping has fewer than two segments, or a numeric
                segment is out of range.
        """
        parts = DEFAULT_INPUTS_MAPPINGS[key].split(".")

        base_part = parts.pop(0)
        data_part = parts.pop(0)

        if data_part != "data":
            base = {}
        elif base_part == "request":
            base = request
        elif base_part == "revision":
            base = revision
        else:
            base = {}

        fan_out = False

        for part in parts:
            if part in ("[]", "{}"):
                expected_type = list if part == "[]" else dict
                if not fan_out and isinstance(base, expected_type):
                    fan_out = True
                # TODO: a marker on a value of the wrong type, or a second
                # marker, is a mapping error; for now the value is left as is.
                continue

            if not fan_out:
                base = self._descend(base, part)
            elif isinstance(base, list):
                base = [self._descend_element(item, part) for item in base]
            else:
                base = {
                    name: self._descend_element(value, part)
                    for name, value in base.items()
                }

        return base

    @staticmethod
    def _descend(value, part: str):
        """Apply one path segment to ``value``; unsupported types pass through."""
        if part.isdigit():
            if isinstance(value, list):
                return value[int(part)]
            if isinstance(value, dict):
                # Positional access into a dict follows insertion order.
                return value[list(value.keys())[int(part)]]
            return value

        if isinstance(value, dict):
            return value.get(part, None)

        return value

    @staticmethod
    def _descend_element(item, part: str):
        """Apply one path segment to a fanned-out element; ``None`` if impossible."""
        if isinstance(item, dict) or (isinstance(item, list) and part.isdigit()):
            return AdaptMiddleware_descend(item, part)
        return None


def AdaptMiddleware_descend(value, part: str):
    """Apply one path segment to ``value`` (shared by the fan-out helper).

    Module-level twin of the single-step lookup: the class name is rebound to
    a decorator by ``middleware_as_decorator``, so static helpers cannot reach
    each other through the class name.
    """
    if part.isdigit():
        if isinstance(value, list):
            return value[int(part)]
        if isinstance(value, dict):
            return value[list(value.keys())[int(part)]]
        return value

    if isinstance(value, dict):
        return value.get(part, None)

    return value
|
agenta/sdk/middleware/auth.py
CHANGED
|
@@ -15,6 +15,17 @@ from agenta.sdk.utils.logging import get_module_logger
|
|
|
15
15
|
|
|
16
16
|
import agenta as ag
|
|
17
17
|
|
|
18
|
+
from agenta.sdk.middleware.base import (
|
|
19
|
+
WorkflowMiddleware,
|
|
20
|
+
middleware_as_decorator,
|
|
21
|
+
)
|
|
22
|
+
from agenta.sdk.workflows.types import (
|
|
23
|
+
WorkflowServiceRequest,
|
|
24
|
+
WorkflowServiceResponse,
|
|
25
|
+
WorkflowRevision,
|
|
26
|
+
WorkflowServiceHandler,
|
|
27
|
+
)
|
|
28
|
+
|
|
18
29
|
log = get_module_logger(__name__)
|
|
19
30
|
|
|
20
31
|
AGENTA_RUNTIME_PREFIX = getenv("AGENTA_RUNTIME_PREFIX", "")
|
|
@@ -53,7 +64,7 @@ class DenyException(Exception):
|
|
|
53
64
|
self.content = content
|
|
54
65
|
|
|
55
66
|
|
|
56
|
-
class
|
|
67
|
+
class AuthHTTPMiddleware(BaseHTTPMiddleware):
|
|
57
68
|
def __init__(self, app: FastAPI):
|
|
58
69
|
super().__init__(app)
|
|
59
70
|
|
|
@@ -157,7 +168,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
157
168
|
# log.debug(f"Timeout error while verify credentials: {exc}")
|
|
158
169
|
raise DenyException(
|
|
159
170
|
status_code=504,
|
|
160
|
-
content="Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
|
|
171
|
+
content=f"Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
|
|
161
172
|
) from exc
|
|
162
173
|
except httpx.ConnectError as exc:
|
|
163
174
|
# log.debug(f"Connection error while verify credentials: {exc}")
|
|
@@ -169,7 +180,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
169
180
|
# log.debug(f"Network error while verify credentials: {exc}")
|
|
170
181
|
raise DenyException(
|
|
171
182
|
status_code=503,
|
|
172
|
-
content="Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
|
|
183
|
+
content=f"Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
|
|
173
184
|
) from exc
|
|
174
185
|
except httpx.HTTPError as exc:
|
|
175
186
|
# log.debug(f"HTTP error while verify credentials: {exc}")
|
|
@@ -253,3 +264,168 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
253
264
|
status_code=500,
|
|
254
265
|
content=f"Could not verify credentials: unexpected error - {str(exc)}. Please try again later or contact support if the issue persists.",
|
|
255
266
|
) from exc
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
from agenta.sdk.context.tracing import (
|
|
270
|
+
tracing_context_manager,
|
|
271
|
+
tracing_context,
|
|
272
|
+
TracingContext,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@middleware_as_decorator
class AuthMiddleware(WorkflowMiddleware):
    """Workflow middleware that copies request credentials into the tracing context."""

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: WorkflowServiceHandler,
    ) -> WorkflowServiceResponse:
        """Run ``handler`` with the request credentials set on the tracing context.

        Args:
            request: Incoming workflow request; its ``credentials`` are read.
            revision: Workflow revision, forwarded untouched.
            handler: Next handler in the chain, awaited with
                ``(request, revision)``.

        Returns:
            Whatever ``handler`` returns.
        """
        current_context = tracing_context.get()

        # The fetched context object is updated in place and then installed
        # for the duration of the handler call.
        current_context.credentials = request.credentials

        with tracing_context_manager(context=current_context):
            return await handler(request, revision)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# @middleware_as_decorator
|
|
296
|
+
# class AuthMiddleware(WorkflowMiddleware):
|
|
297
|
+
# def __init__(self):
|
|
298
|
+
# self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
|
|
299
|
+
# self.scope_type = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.scope_type
|
|
300
|
+
# self.scope_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.scope_id
|
|
301
|
+
|
|
302
|
+
# async def __call__(
|
|
303
|
+
# self,
|
|
304
|
+
# request: WorkflowServiceRequest,
|
|
305
|
+
# revision: WorkflowRevision,
|
|
306
|
+
# handler: WorkflowServiceHandler,
|
|
307
|
+
# ) -> WorkflowServiceResponse:
|
|
308
|
+
# try:
|
|
309
|
+
# request.credentials = await self._get_credentials(request)
|
|
310
|
+
|
|
311
|
+
# except DenyException as deny:
|
|
312
|
+
# display_exception("Auth Middleware Exception")
|
|
313
|
+
|
|
314
|
+
# raise deny
|
|
315
|
+
|
|
316
|
+
# except Exception as exc:
|
|
317
|
+
# display_exception("Auth Middleware Exception")
|
|
318
|
+
|
|
319
|
+
# raise DenyException(
|
|
320
|
+
# status_code=500,
|
|
321
|
+
# content="Auth Middleware Unexpected Error.",
|
|
322
|
+
# ) from exc
|
|
323
|
+
|
|
324
|
+
# return await handler(request, revision)
|
|
325
|
+
|
|
326
|
+
# async def _get_credentials(
|
|
327
|
+
# self,
|
|
328
|
+
# request: WorkflowServiceRequest,
|
|
329
|
+
# ) -> Optional[str]:
|
|
330
|
+
# credentials = request.credentials
|
|
331
|
+
|
|
332
|
+
# headers = {"Authorization": credentials} if credentials else None
|
|
333
|
+
|
|
334
|
+
# params = {
|
|
335
|
+
# "action": "run_workflow",
|
|
336
|
+
# "resource_type": "workflow",
|
|
337
|
+
# }
|
|
338
|
+
# if self.scope_type and self.scope_id:
|
|
339
|
+
# params["scope_type"] = self.scope_type
|
|
340
|
+
# params["scope_id"] = self.scope_id
|
|
341
|
+
|
|
342
|
+
# _hash = dumps(
|
|
343
|
+
# {
|
|
344
|
+
# "headers": headers,
|
|
345
|
+
# "params": params,
|
|
346
|
+
# },
|
|
347
|
+
# sort_keys=True,
|
|
348
|
+
# )
|
|
349
|
+
|
|
350
|
+
# if _CACHE_ENABLED:
|
|
351
|
+
# cached = _cache.get(_hash)
|
|
352
|
+
# if cached:
|
|
353
|
+
# return cached
|
|
354
|
+
|
|
355
|
+
# try:
|
|
356
|
+
# async with httpx.AsyncClient() as client:
|
|
357
|
+
# response = await client.get(
|
|
358
|
+
# f"{self.host}/api/permissions/verify",
|
|
359
|
+
# params=params,
|
|
360
|
+
# headers=headers,
|
|
361
|
+
# timeout=5 * 60,
|
|
362
|
+
# )
|
|
363
|
+
|
|
364
|
+
# except httpx.TimeoutException as exc:
|
|
365
|
+
# raise DenyException(
|
|
366
|
+
# status_code=504,
|
|
367
|
+
# content=f"Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
|
|
368
|
+
# ) from exc
|
|
369
|
+
# except httpx.ConnectError as exc:
|
|
370
|
+
# raise DenyException(
|
|
371
|
+
# status_code=503,
|
|
372
|
+
# content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
|
|
373
|
+
# ) from exc
|
|
374
|
+
# except httpx.NetworkError as exc:
|
|
375
|
+
# raise DenyException(
|
|
376
|
+
# status_code=503,
|
|
377
|
+
# content=f"Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
|
|
378
|
+
# ) from exc
|
|
379
|
+
# except httpx.HTTPError as exc:
|
|
380
|
+
# raise DenyException(
|
|
381
|
+
# status_code=502,
|
|
382
|
+
# content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
|
|
383
|
+
# ) from exc
|
|
384
|
+
# except Exception as exc:
|
|
385
|
+
# raise DenyException(
|
|
386
|
+
# 500,
|
|
387
|
+
# f"Could not verify credentials: unexpected error.\n {exc}",
|
|
388
|
+
# ) from exc
|
|
389
|
+
|
|
390
|
+
# if response.status_code == 401:
|
|
391
|
+
# raise DenyException(
|
|
392
|
+
# status_code=401,
|
|
393
|
+
# content="Invalid credentials. Please check your credentials or login again.",
|
|
394
|
+
# )
|
|
395
|
+
# if response.status_code == 403:
|
|
396
|
+
# raise DenyException(
|
|
397
|
+
# status_code=403,
|
|
398
|
+
# content="Permission denied. Please check your permissions or contact your administrator.",
|
|
399
|
+
# )
|
|
400
|
+
# if response.status_code != 200:
|
|
401
|
+
# raise DenyException(
|
|
402
|
+
# status_code=500,
|
|
403
|
+
# content=f"Could not verify credentials: {self.host} returned unexpected status code {response.status_code}. Please try again later or contact support if the issue persists.",
|
|
404
|
+
# )
|
|
405
|
+
|
|
406
|
+
# try:
|
|
407
|
+
# auth = response.json()
|
|
408
|
+
# except ValueError as exc:
|
|
409
|
+
# raise DenyException(
|
|
410
|
+
# status_code=500,
|
|
411
|
+
# content=f"Could not verify credentials: {self.host} returned unexpected invalid JSON response. Please try again later or contact support if the issue persists.",
|
|
412
|
+
# ) from exc
|
|
413
|
+
|
|
414
|
+
# if not isinstance(auth, dict):
|
|
415
|
+
# raise DenyException(
|
|
416
|
+
# status_code=500,
|
|
417
|
+
# content=f"Could not verify credentials: {self.host} returned unexpected invalid response format. Please try again later or contact support if the issue persists.",
|
|
418
|
+
# )
|
|
419
|
+
|
|
420
|
+
# if auth.get("effect") != "allow":
|
|
421
|
+
# raise DenyException(
|
|
422
|
+
# status_code=403,
|
|
423
|
+
# content="Permission denied. Please check your permissions or contact your administrator.",
|
|
424
|
+
# )
|
|
425
|
+
|
|
426
|
+
# credentials: str = auth.get("credentials")
|
|
427
|
+
|
|
428
|
+
# if credentials is not None:
|
|
429
|
+
# _cache.put(_hash, credentials)
|
|
430
|
+
|
|
431
|
+
# return credentials
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Protocol, Callable, Any
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.workflows.types import (
|
|
4
|
+
WorkflowServiceRequest,
|
|
5
|
+
WorkflowServiceResponse,
|
|
6
|
+
WorkflowRevision,
|
|
7
|
+
WorkflowServiceHandler,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WorkflowMiddleware(Protocol):
    """Structural type for workflow middlewares.

    A middleware is an async callable receiving the request, the revision and
    the next handler in the chain.
    """

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: Callable,
    ) -> Any:
        """Process ``request``/``revision``, delegating to ``handler`` as needed."""
        ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
WorkflowMiddlewareDecorator = Callable[[WorkflowServiceHandler], WorkflowServiceHandler]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def middleware_as_decorator(
    middleware: WorkflowMiddleware | type[WorkflowMiddleware],
) -> WorkflowMiddlewareDecorator:
    """Turn a workflow middleware into a decorator for workflow handlers.

    Args:
        middleware: A middleware instance, or a middleware class. A class is
            instantiated once, with no arguments, when this function is
            called.

    Returns:
        A decorator mapping a handler to an async ``(request, revision)``
        callable that awaits ``middleware(request, revision, handler)``.
    """
    if isinstance(middleware, type):
        # Applied to a class (e.g. as a class decorator): build the single
        # shared instance now.
        active_middleware = middleware()
    else:
        active_middleware = middleware

    def decorator(
        handler: WorkflowServiceHandler,
    ) -> WorkflowServiceHandler:
        # NOTE(review): the wrapper keeps its own (request, revision)
        # signature; other code in this package inspects handler signatures,
        # so confirm before adding functools.wraps here.
        async def wrapped(
            request: WorkflowServiceRequest,
            revision: WorkflowRevision,
        ) -> WorkflowServiceResponse:
            return await active_middleware(request, revision, handler)

        return wrapped

    return decorator
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from agenta.sdk.utils.logging import get_module_logger
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.middleware.base import (
|
|
4
|
+
WorkflowMiddleware,
|
|
5
|
+
middleware_as_decorator,
|
|
6
|
+
)
|
|
7
|
+
from agenta.sdk.workflows.types import (
|
|
8
|
+
WorkflowServiceRequest,
|
|
9
|
+
WorkflowServiceResponse,
|
|
10
|
+
WorkflowRevision,
|
|
11
|
+
WorkflowServiceHandler,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from agenta.sdk.context.tracing import (
|
|
15
|
+
tracing_context_manager,
|
|
16
|
+
tracing_context,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
log = get_module_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@middleware_as_decorator
class FlagsMiddleware(WorkflowMiddleware):
    """Workflow middleware that reflects request flags in the tracing context."""

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: WorkflowServiceHandler,
    ) -> WorkflowServiceResponse:
        """Run ``handler``, marking the trace as an annotation when flagged.

        Args:
            request: Incoming workflow request; its ``flags`` are read.
            revision: Workflow revision, forwarded untouched.
            handler: Next handler in the chain, awaited with
                ``(request, revision)``.

        Returns:
            Whatever ``handler`` returns.
        """
        current_context = tracing_context.get()

        flags = request.flags
        # Only a dict with a truthy "is_annotation" entry changes the type.
        is_annotation = isinstance(flags, dict) and flags.get("is_annotation")

        if is_annotation:
            current_context.type = "annotation"

        with tracing_context_manager(context=current_context):
            return await handler(request, revision)
|
agenta/sdk/tracing/exporters.py
CHANGED
|
@@ -8,15 +8,15 @@ from opentelemetry.sdk.trace.export import (
|
|
|
8
8
|
ReadableSpan,
|
|
9
9
|
)
|
|
10
10
|
|
|
11
|
+
from agenta.sdk.utils.logging import get_module_logger
|
|
11
12
|
from agenta.sdk.utils.exceptions import suppress
|
|
12
|
-
from agenta.sdk.context.exporting import (
|
|
13
|
-
exporting_context_manager,
|
|
14
|
-
exporting_context,
|
|
15
|
-
ExportingContext,
|
|
16
|
-
)
|
|
17
13
|
from agenta.sdk.utils.cache import TTLLRUCache
|
|
14
|
+
from agenta.sdk.context.tracing import (
|
|
15
|
+
tracing_exporter_context_manager,
|
|
16
|
+
tracing_exporter_context,
|
|
17
|
+
TracingExporterContext,
|
|
18
|
+
)
|
|
18
19
|
|
|
19
|
-
from agenta.sdk.utils.logging import get_module_logger
|
|
20
20
|
|
|
21
21
|
log = get_module_logger(__name__)
|
|
22
22
|
|
|
@@ -101,8 +101,8 @@ class OTLPExporter(OTLPSpanExporter):
|
|
|
101
101
|
serialized_spans = []
|
|
102
102
|
|
|
103
103
|
for credentials, _spans in grouped_spans.items():
|
|
104
|
-
with
|
|
105
|
-
context=
|
|
104
|
+
with tracing_exporter_context_manager(
|
|
105
|
+
context=TracingExporterContext(
|
|
106
106
|
credentials=credentials,
|
|
107
107
|
)
|
|
108
108
|
):
|
|
@@ -114,7 +114,7 @@ class OTLPExporter(OTLPSpanExporter):
|
|
|
114
114
|
return SpanExportResult.FAILURE
|
|
115
115
|
|
|
116
116
|
def _export(self, serialized_data: bytes, timeout_sec: Optional[float] = None):
|
|
117
|
-
credentials =
|
|
117
|
+
credentials = tracing_exporter_context.get().credentials
|
|
118
118
|
|
|
119
119
|
if credentials:
|
|
120
120
|
self._session.headers.update({"Authorization": credentials})
|
agenta/sdk/tracing/inline.py
CHANGED
agenta/sdk/tracing/tracing.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
from typing import Optional, Any, Dict, Callable
|
|
2
2
|
from enum import Enum
|
|
3
|
-
from uuid import UUID
|
|
4
3
|
|
|
5
4
|
from pydantic import BaseModel
|
|
6
|
-
from httpx import get as check
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
from opentelemetry.trace import (
|
|
@@ -33,7 +31,6 @@ from agenta.sdk.tracing.conventions import Reference, is_valid_attribute_key
|
|
|
33
31
|
from agenta.sdk.tracing.propagation import extract, inject
|
|
34
32
|
from agenta.sdk.utils.cache import TTLLRUCache
|
|
35
33
|
|
|
36
|
-
from agenta.sdk.context.tracing import tracing_context
|
|
37
34
|
|
|
38
35
|
log = get_module_logger(__name__)
|
|
39
36
|
|
|
@@ -120,11 +117,6 @@ class Tracing(metaclass=Singleton):
|
|
|
120
117
|
# TRACE PROCESSORS -- OTLP
|
|
121
118
|
try:
|
|
122
119
|
log.info("Agenta - OLTP URL: %s", self.otlp_url)
|
|
123
|
-
# check(
|
|
124
|
-
# self.otlp_url,
|
|
125
|
-
# headers=self.headers,
|
|
126
|
-
# timeout=1,
|
|
127
|
-
# )
|
|
128
120
|
|
|
129
121
|
_otlp = TraceProcessor(
|
|
130
122
|
OTLPExporter(
|
agenta/sdk/utils/cache.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections import OrderedDict
|
|
|
5
5
|
from threading import Lock
|
|
6
6
|
|
|
7
7
|
CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512"))
|
|
8
|
-
CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(60))) #
|
|
8
|
+
CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class TTLLRUCache:
|
|
@@ -38,7 +38,7 @@ class TTLLRUCache:
|
|
|
38
38
|
|
|
39
39
|
return value
|
|
40
40
|
|
|
41
|
-
def put(self, key, value):
|
|
41
|
+
def put(self, key, value, ttl: Optional[int] = None):
|
|
42
42
|
with self.lock:
|
|
43
43
|
try:
|
|
44
44
|
# LRU update
|
|
@@ -50,4 +50,4 @@ class TTLLRUCache:
|
|
|
50
50
|
self.cache.popitem(last=False)
|
|
51
51
|
|
|
52
52
|
# Put
|
|
53
|
-
self.cache[key] = (value, time() + self.ttl)
|
|
53
|
+
self.cache[key] = (value, time() + (ttl if ttl is not None else self.ttl))
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from json import dumps
|
|
2
|
+
|
|
3
|
+
from agenta.sdk.utils.logging import get_module_logger
|
|
4
|
+
from agenta.sdk.workflows.types import Data
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
log = get_module_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
async def exact_match_v1(
    *,
    parameters: Data,
    inputs: Data,
    outputs: Data | str,
) -> Data:
    """Check whether ``outputs`` exactly matches a reference value from ``inputs``.

    The reference is ``inputs[parameters["reference_key"]]``. Two strings are
    compared directly; two dicts are compared through their key-sorted JSON
    serialization. Any other combination of types counts as a mismatch.

    Args:
        parameters: Evaluator settings; ``"reference_key"`` names the entry of
            ``inputs`` that holds the reference value.
        inputs: Data the reference value is looked up in.
        outputs: Value to compare against the reference.

    Returns:
        ``{"success": True}`` on an exact match, ``{"success": False}``
        otherwise, including when the comparison itself fails.
    """
    success = False

    try:
        reference_key = parameters.get("reference_key", None)
        reference_outputs = inputs.get(reference_key, None)

        if isinstance(outputs, str) and isinstance(reference_outputs, str):
            success = outputs == reference_outputs
        elif isinstance(outputs, dict) and isinstance(reference_outputs, dict):
            # Key-sorted JSON makes the comparison independent of key order.
            success = dumps(outputs, sort_keys=True) == dumps(
                reference_outputs, sort_keys=True
            )

    except Exception:  # pylint: disable=broad-exception-caught
        # Best effort: an evaluation error is reported as a failed match.
        # Only Exception is caught so KeyboardInterrupt/SystemExit propagate.
        log.error("Error in exact_match_v1", exc_info=True)

    return {"success": success}
|