agenta 0.50.6__py3-none-any.whl → 0.51.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic. Click here for more details.

@@ -0,0 +1,253 @@
1
+ from typing import Callable
2
+ from inspect import signature
3
+ from uuid import uuid4
4
+
5
+ from agenta.sdk.utils.logging import get_module_logger
6
+
7
+ from agenta.sdk.middleware.base import (
8
+ WorkflowMiddleware,
9
+ middleware_as_decorator,
10
+ )
11
+ from agenta.sdk.workflows.types import (
12
+ WorkflowServiceRequest,
13
+ WorkflowServiceResponse,
14
+ WorkflowServiceData,
15
+ WorkflowRevision,
16
+ Status,
17
+ )
18
+
19
+ log = get_module_logger(__name__)
20
+
21
# Where each handler keyword argument comes from. Values are dotted paths:
# the first two segments pick the root ("request.data" -> the dumped request
# data, "revision.data" -> the dumped revision data) and the remaining
# segments walk into it (digit segments index positionally, "{}" / "[]" mark
# fan-out). The "request", "revision" and "parameters" entries are never
# walked: AdaptMiddleware supplies those objects directly.
DEFAULT_INPUTS_MAPPINGS = dict(
    request="request",
    revision="revision",
    parameters="revision.data.parameters",
    inputs="request.data.inputs",
    outputs="request.data.traces.0.attributes.ag.data.outputs",
    trace="request.data.traces.0",
    trace_outputs="request.data.traces.0.attributes.ag.data.outputs",
    traces="request.data.traces",
    traces_outputs="request.data.traces.{}.attributes.ag.data.outputs",
)

# Handler parameter names the adapter knows how to fill (iterating a dict
# yields its keys).
ALLOWED_INPUTS_KEYS = set(DEFAULT_INPUTS_MAPPINGS)

# Output keys; not read by AdaptMiddleware in this module.
ALLOWED_OUTPUTS_KEYS = set(("outputs", "trace"))

DEFAULT_MAPPINGS = dict()

# Version string stamped on every WorkflowServiceResponse built by the adapter.
CURRENT_VERSION = "2025.07.14"
43
+
44
+
45
@middleware_as_decorator
class AdaptMiddleware(WorkflowMiddleware):
    """Call a plain handler by keyword, pulling its arguments out of the request.

    The wrapped handler declares what it wants through its parameter names
    (``inputs``, ``outputs``, ``parameters``, ``trace``, ...). Every name listed
    in ``DEFAULT_INPUTS_MAPPINGS`` is resolved from the request / revision, and
    the handler's return value becomes ``data.outputs`` of a new
    ``WorkflowServiceResponse``.
    """

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: Callable,
    ) -> WorkflowServiceResponse:
        """Resolve the handler's keyword arguments, await it and wrap its result.

        Args:
            request: Incoming service request; ``request.data`` supplies inputs
                and traces.
            revision: Workflow revision; ``revision.data.parameters`` supplies
                the ``parameters`` argument.
            handler: Async callable whose parameter names select what it
                receives.

        Returns:
            A ``WorkflowServiceResponse`` with a fresh id, ``CURRENT_VERSION``
            and the handler's return value as ``data.outputs``.

        Raises:
            Whatever ``handler`` raises, re-raised unchanged after being logged.
        """
        handler_signature = signature(handler)

        kwargs = self._resolve_handler_kwargs(
            request=request,
            revision=revision,
            handler_signature=handler_signature,
        )

        # TODO: validate kwargs["inputs"] / kwargs["parameters"] against the
        # revision's schemas once those are available here.

        try:
            handler_signature.bind(**kwargs)
        except TypeError:
            # Diagnostic only: the call below raises the definitive error.
            log.debug(
                "Resolved kwargs %s do not match handler signature %s",
                sorted(kwargs),
                handler_signature,
            )

        try:
            outputs = await handler(**kwargs)
        except Exception:
            # Logging needs a message argument; the original exception is
            # re-raised untouched so callers see the real failure.
            log.debug("Workflow handler raised an exception", exc_info=True)
            raise

        trace = None  # TODO: attach the trace produced by the handler.

        # TODO: validate outputs against the revision's outputs schema.

        return WorkflowServiceResponse(
            id=uuid4(),
            version=CURRENT_VERSION,
            # TODO: set status (e.g. Status(code=200, message="Success")).
            data=WorkflowServiceData(
                outputs=outputs,
                trace=trace,
            ),
        )

    def _resolve_handler_kwargs(
        self,
        *,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler_signature,
    ) -> dict:
        """Build the keyword arguments for a handler with ``handler_signature``.

        - ``*args`` / ``**kwargs`` parameters are skipped: they cannot be filled
          by name.
        - Names outside ``ALLOWED_INPUTS_KEYS`` receive ``None`` when required,
          and are omitted when they have a default so that default applies.
        - Allowed names are resolved one at a time; if resolving one fails, the
          failure is logged and that name alone receives ``None``.

        Args:
            request: Incoming service request.
            revision: Workflow revision.
            handler_signature: ``inspect.Signature`` of the handler.

        Returns:
            Mapping of parameter name to value, ready for ``handler(**kwargs)``.
        """
        request_data = request.data.model_dump(
            mode="json",
            exclude_none=True,
        )
        revision_data = revision.data.model_dump(
            mode="json",
            exclude_none=True,
        )

        kwargs = {}

        for name, parameter in handler_signature.parameters.items():
            if parameter.kind in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD):
                continue

            if name not in ALLOWED_INPUTS_KEYS:
                if parameter.default is parameter.empty:
                    # Required but unknown: pass None so the call stays bindable.
                    kwargs[name] = None
                continue

            try:
                if name == "request":
                    kwargs[name] = request
                elif name == "revision":
                    kwargs[name] = revision
                elif name == "parameters":
                    # Empty / missing parameters are passed as None.
                    kwargs[name] = revision.data.parameters or None
                elif name in request_data:
                    # A value sent verbatim in request.data wins over a path lookup.
                    kwargs[name] = request_data[name]
                else:
                    kwargs[name] = self._apply_request_mapping(
                        request=request_data,
                        revision=revision_data,
                        key=name,
                    )
            except Exception:  # pylint: disable=broad-except
                log.debug("Could not resolve handler input %r", name, exc_info=True)
                kwargs[name] = None

        return kwargs

    def _apply_request_mapping(
        self,
        request: dict,
        revision: dict,
        key: str,
    ):
        """Resolve ``DEFAULT_INPUTS_MAPPINGS[key]`` against the dumped data.

        The path's first two segments select the root: ``request.data`` ->
        ``request``, ``revision.data`` -> ``revision``; anything else starts
        from an empty dict. The remaining segments are applied in order:

        - a digit segment indexes positionally (into a list, or into a dict's
          values in insertion order);
        - ``[]`` fans out over a list and ``{}`` over a dict's values; every
          later segment is then applied to each element (a dict fan-out keeps
          its keys);
        - any other segment is a dict key lookup.

        Missing keys, out-of-range indices, type mismatches and nested fan-outs
        resolve to ``None`` instead of raising.

        Args:
            request: ``request.data`` dumped to a JSON-compatible dict.
            revision: ``revision.data`` dumped to a JSON-compatible dict.
            key: A key of ``DEFAULT_INPUTS_MAPPINGS``.

        Raises:
            KeyError: If ``key`` has no mapping.
        """
        segments = DEFAULT_INPUTS_MAPPINGS[key].split(".")
        root = tuple(segments[:2])
        parts = segments[2:]

        if root == ("request", "data"):
            node = request
        elif root == ("revision", "data"):
            node = revision
        else:
            node = {}

        fan_out = None  # None, "[]" or "{}" once a fan-out segment was seen

        for part in parts:
            if part in ("[]", "{}"):
                expected_type = list if part == "[]" else dict
                if fan_out is not None or not isinstance(node, expected_type):
                    # Nested fan-out, or fan-out over the wrong container.
                    return None
                fan_out = part
            elif fan_out == "[]":
                node = [self._resolve_path_segment(item, part) for item in node]
            elif fan_out == "{}":
                node = {
                    name: self._resolve_path_segment(value, part)
                    for name, value in node.items()
                }
            else:
                node = self._resolve_path_segment(node, part)

        return node

    @staticmethod
    def _resolve_path_segment(node, part: str):
        """Apply one non-fan-out path segment to ``node``; ``None`` if it does not apply."""
        if part.isdigit():
            items = list(node.values()) if isinstance(node, dict) else node
            index = int(part)
            if isinstance(items, list) and index < len(items):
                return items[index]
            return None

        if isinstance(node, dict):
            return node.get(part)

        return None
@@ -15,6 +15,17 @@ from agenta.sdk.utils.logging import get_module_logger
15
15
 
16
16
  import agenta as ag
17
17
 
18
+ from agenta.sdk.middleware.base import (
19
+ WorkflowMiddleware,
20
+ middleware_as_decorator,
21
+ )
22
+ from agenta.sdk.workflows.types import (
23
+ WorkflowServiceRequest,
24
+ WorkflowServiceResponse,
25
+ WorkflowRevision,
26
+ WorkflowServiceHandler,
27
+ )
28
+
18
29
  log = get_module_logger(__name__)
19
30
 
20
31
  AGENTA_RUNTIME_PREFIX = getenv("AGENTA_RUNTIME_PREFIX", "")
@@ -53,7 +64,7 @@ class DenyException(Exception):
53
64
  self.content = content
54
65
 
55
66
 
56
- class AuthMiddleware(BaseHTTPMiddleware):
67
+ class AuthHTTPMiddleware(BaseHTTPMiddleware):
57
68
  def __init__(self, app: FastAPI):
58
69
  super().__init__(app)
59
70
 
@@ -157,7 +168,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
157
168
  # log.debug(f"Timeout error while verify credentials: {exc}")
158
169
  raise DenyException(
159
170
  status_code=504,
160
- content="Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
171
+ content=f"Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
161
172
  ) from exc
162
173
  except httpx.ConnectError as exc:
163
174
  # log.debug(f"Connection error while verify credentials: {exc}")
@@ -169,7 +180,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
169
180
  # log.debug(f"Network error while verify credentials: {exc}")
170
181
  raise DenyException(
171
182
  status_code=503,
172
- content="Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
183
+ content=f"Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
173
184
  ) from exc
174
185
  except httpx.HTTPError as exc:
175
186
  # log.debug(f"HTTP error while verify credentials: {exc}")
@@ -253,3 +264,168 @@ class AuthMiddleware(BaseHTTPMiddleware):
253
264
  status_code=500,
254
265
  content=f"Could not verify credentials: unexpected error - {str(exc)}. Please try again later or contact support if the issue persists.",
255
266
  ) from exc
267
+
268
+
269
+ from agenta.sdk.context.tracing import (
270
+ tracing_context_manager,
271
+ tracing_context,
272
+ TracingContext,
273
+ )
274
+
275
+
276
@middleware_as_decorator
class AuthMiddleware(WorkflowMiddleware):
    """Expose the request's credentials on the tracing context while the handler runs."""

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: WorkflowServiceHandler,
    ) -> WorkflowServiceResponse:
        """Run ``handler(request, revision)`` with ``request.credentials`` on the tracing context.

        The current context object is copied before it is modified, so the
        credentials are visible only inside ``tracing_context_manager`` and do
        not stick to the object that ``tracing_context.get()`` returned.
        NOTE(review): that object may be shared (e.g. a ContextVar default
        instance) — confirm in ``agenta.sdk.context.tracing``; mutating it in
        place would leak credentials beyond this request.

        Args:
            request: Incoming service request carrying ``credentials``.
            revision: Workflow revision, passed through unchanged.
            handler: Next handler in the middleware chain.

        Returns:
            Whatever ``handler`` returns.
        """
        from copy import copy  # stdlib shallow copy of the context object

        ctx = copy(tracing_context.get())
        ctx.credentials = request.credentials

        with tracing_context_manager(context=ctx):
            return await handler(request, revision)
293
+
294
+
295
+ # @middleware_as_decorator
296
+ # class AuthMiddleware(WorkflowMiddleware):
297
+ # def __init__(self):
298
+ # self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
299
+ # self.scope_type = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.scope_type
300
+ # self.scope_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.scope_id
301
+
302
+ # async def __call__(
303
+ # self,
304
+ # request: WorkflowServiceRequest,
305
+ # revision: WorkflowRevision,
306
+ # handler: WorkflowServiceHandler,
307
+ # ) -> WorkflowServiceResponse:
308
+ # try:
309
+ # request.credentials = await self._get_credentials(request)
310
+
311
+ # except DenyException as deny:
312
+ # display_exception("Auth Middleware Exception")
313
+
314
+ # raise deny
315
+
316
+ # except Exception as exc:
317
+ # display_exception("Auth Middleware Exception")
318
+
319
+ # raise DenyException(
320
+ # status_code=500,
321
+ # content="Auth Middleware Unexpected Error.",
322
+ # ) from exc
323
+
324
+ # return await handler(request, revision)
325
+
326
+ # async def _get_credentials(
327
+ # self,
328
+ # request: WorkflowServiceRequest,
329
+ # ) -> Optional[str]:
330
+ # credentials = request.credentials
331
+
332
+ # headers = {"Authorization": credentials} if credentials else None
333
+
334
+ # params = {
335
+ # "action": "run_workflow",
336
+ # "resource_type": "workflow",
337
+ # }
338
+ # if self.scope_type and self.scope_id:
339
+ # params["scope_type"] = self.scope_type
340
+ # params["scope_id"] = self.scope_id
341
+
342
+ # _hash = dumps(
343
+ # {
344
+ # "headers": headers,
345
+ # "params": params,
346
+ # },
347
+ # sort_keys=True,
348
+ # )
349
+
350
+ # if _CACHE_ENABLED:
351
+ # cached = _cache.get(_hash)
352
+ # if cached:
353
+ # return cached
354
+
355
+ # try:
356
+ # async with httpx.AsyncClient() as client:
357
+ # response = await client.get(
358
+ # f"{self.host}/api/permissions/verify",
359
+ # params=params,
360
+ # headers=headers,
361
+ # timeout=5 * 60,
362
+ # )
363
+
364
+ # except httpx.TimeoutException as exc:
365
+ # raise DenyException(
366
+ # status_code=504,
367
+ # content=f"Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
368
+ # ) from exc
369
+ # except httpx.ConnectError as exc:
370
+ # raise DenyException(
371
+ # status_code=503,
372
+ # content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
373
+ # ) from exc
374
+ # except httpx.NetworkError as exc:
375
+ # raise DenyException(
376
+ # status_code=503,
377
+ # content=f"Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
378
+ # ) from exc
379
+ # except httpx.HTTPError as exc:
380
+ # raise DenyException(
381
+ # status_code=502,
382
+ # content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
383
+ # ) from exc
384
+ # except Exception as exc:
385
+ # raise DenyException(
386
+ # 500,
387
+ # f"Could not verify credentials: unexpected error.\n {exc}",
388
+ # ) from exc
389
+
390
+ # if response.status_code == 401:
391
+ # raise DenyException(
392
+ # status_code=401,
393
+ # content="Invalid credentials. Please check your credentials or login again.",
394
+ # )
395
+ # if response.status_code == 403:
396
+ # raise DenyException(
397
+ # status_code=403,
398
+ # content="Permission denied. Please check your permissions or contact your administrator.",
399
+ # )
400
+ # if response.status_code != 200:
401
+ # raise DenyException(
402
+ # status_code=500,
403
+ # content=f"Could not verify credentials: {self.host} returned unexpected status code {response.status_code}. Please try again later or contact support if the issue persists.",
404
+ # )
405
+
406
+ # try:
407
+ # auth = response.json()
408
+ # except ValueError as exc:
409
+ # raise DenyException(
410
+ # status_code=500,
411
+ # content=f"Could not verify credentials: {self.host} returned unexpected invalid JSON response. Please try again later or contact support if the issue persists.",
412
+ # ) from exc
413
+
414
+ # if not isinstance(auth, dict):
415
+ # raise DenyException(
416
+ # status_code=500,
417
+ # content=f"Could not verify credentials: {self.host} returned unexpected invalid response format. Please try again later or contact support if the issue persists.",
418
+ # )
419
+
420
+ # if auth.get("effect") != "allow":
421
+ # raise DenyException(
422
+ # status_code=403,
423
+ # content="Permission denied. Please check your permissions or contact your administrator.",
424
+ # )
425
+
426
+ # credentials: str = auth.get("credentials")
427
+
428
+ # if credentials is not None:
429
+ # _cache.put(_hash, credentials)
430
+
431
+ # return credentials
@@ -0,0 +1,40 @@
1
+ from typing import Protocol, Callable, Any
2
+
3
+ from agenta.sdk.workflows.types import (
4
+ WorkflowServiceRequest,
5
+ WorkflowServiceResponse,
6
+ WorkflowRevision,
7
+ WorkflowServiceHandler,
8
+ )
9
+
10
+
11
class WorkflowMiddleware(Protocol):
    """Structural type for workflow middlewares.

    An implementation is an object whose ``__call__`` coroutine receives the
    request, the revision and the next ``handler`` in the chain, and returns
    whatever the chain should produce.
    """

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: Callable,
    ) -> Any:
        """Process ``request`` for ``revision``, typically by awaiting ``handler``."""
19
+
20
+
21
# Signature of the decorators produced by ``middleware_as_decorator``.
WorkflowMiddlewareDecorator = Callable[[WorkflowServiceHandler], WorkflowServiceHandler]


def middleware_as_decorator(
    middleware: WorkflowMiddleware | type[WorkflowMiddleware],
) -> WorkflowMiddlewareDecorator:
    """Turn a middleware (class or instance) into a handler decorator.

    A class is instantiated once, with no arguments, when this function runs;
    an instance is used as-is. The returned decorator wraps a handler so that
    ``await wrapped(request, revision)`` evaluates
    ``await middleware(request, revision, handler)``.

    Applied as a class decorator (``@middleware_as_decorator``), this rebinds
    the class name to the resulting handler decorator.

    Note: the wrapper does not copy the handler's metadata (name, docstring,
    ``__wrapped__``), so ``inspect.signature(wrapped)`` reports
    ``(request, revision)``.
    """
    if isinstance(middleware, type):
        instance = middleware()
    else:
        instance = middleware

    def decorator(handler: WorkflowServiceHandler) -> WorkflowServiceHandler:
        async def wrapped(
            request: WorkflowServiceRequest,
            revision: WorkflowRevision,
        ) -> WorkflowServiceResponse:
            return await instance(request, revision, handler)

        return wrapped

    return decorator
@@ -0,0 +1,40 @@
1
+ from agenta.sdk.utils.logging import get_module_logger
2
+
3
+ from agenta.sdk.middleware.base import (
4
+ WorkflowMiddleware,
5
+ middleware_as_decorator,
6
+ )
7
+ from agenta.sdk.workflows.types import (
8
+ WorkflowServiceRequest,
9
+ WorkflowServiceResponse,
10
+ WorkflowRevision,
11
+ WorkflowServiceHandler,
12
+ )
13
+
14
+ from agenta.sdk.context.tracing import (
15
+ tracing_context_manager,
16
+ tracing_context,
17
+ )
18
+
19
+
20
+ log = get_module_logger(__name__)
21
+
22
+
23
@middleware_as_decorator
class FlagsMiddleware(WorkflowMiddleware):
    """Type the tracing context as an annotation when the request's flags ask for it."""

    def __init__(self):
        pass

    async def __call__(
        self,
        request: WorkflowServiceRequest,
        revision: WorkflowRevision,
        handler: WorkflowServiceHandler,
    ) -> WorkflowServiceResponse:
        """Run ``handler(request, revision)``, setting the context type to ``"annotation"``
        when ``request.flags`` is a dict with a truthy ``"is_annotation"`` entry.

        The current context object is copied before it is modified, so the type
        change is visible only inside ``tracing_context_manager``.
        NOTE(review): the object returned by ``tracing_context.get()`` may be
        shared (e.g. a ContextVar default instance) — confirm in
        ``agenta.sdk.context.tracing``; mutating it in place would leak the
        annotation type into other traces.

        Args:
            request: Incoming service request; only ``request.flags`` is read here.
            revision: Workflow revision, passed through unchanged.
            handler: Next handler in the middleware chain.

        Returns:
            Whatever ``handler`` returns.
        """
        from copy import copy  # stdlib shallow copy of the context object

        ctx = copy(tracing_context.get())

        if isinstance(request.flags, dict) and request.flags.get("is_annotation"):
            ctx.type = "annotation"

        with tracing_context_manager(context=ctx):
            return await handler(request, revision)
@@ -8,15 +8,15 @@ from opentelemetry.sdk.trace.export import (
8
8
  ReadableSpan,
9
9
  )
10
10
 
11
+ from agenta.sdk.utils.logging import get_module_logger
11
12
  from agenta.sdk.utils.exceptions import suppress
12
- from agenta.sdk.context.exporting import (
13
- exporting_context_manager,
14
- exporting_context,
15
- ExportingContext,
16
- )
17
13
  from agenta.sdk.utils.cache import TTLLRUCache
14
+ from agenta.sdk.context.tracing import (
15
+ tracing_exporter_context_manager,
16
+ tracing_exporter_context,
17
+ TracingExporterContext,
18
+ )
18
19
 
19
- from agenta.sdk.utils.logging import get_module_logger
20
20
 
21
21
  log = get_module_logger(__name__)
22
22
 
@@ -101,8 +101,8 @@ class OTLPExporter(OTLPSpanExporter):
101
101
  serialized_spans = []
102
102
 
103
103
  for credentials, _spans in grouped_spans.items():
104
- with exporting_context_manager(
105
- context=ExportingContext(
104
+ with tracing_exporter_context_manager(
105
+ context=TracingExporterContext(
106
106
  credentials=credentials,
107
107
  )
108
108
  ):
@@ -114,7 +114,7 @@ class OTLPExporter(OTLPSpanExporter):
114
114
  return SpanExportResult.FAILURE
115
115
 
116
116
  def _export(self, serialized_data: bytes, timeout_sec: Optional[float] = None):
117
- credentials = exporting_context.get().credentials
117
+ credentials = tracing_exporter_context.get().credentials
118
118
 
119
119
  if credentials:
120
120
  self._session.headers.update({"Authorization": credentials})
@@ -62,6 +62,7 @@ Attributes = Dict[str, AttributeValueType]
62
62
  class TreeType(Enum):
63
63
  # --- VARIANTS --- #
64
64
  INVOCATION = "invocation"
65
+ ANNOTATION = "annotation"
65
66
  # --- VARIANTS --- #
66
67
 
67
68
 
@@ -1,9 +1,7 @@
1
1
  from typing import Optional, Any, Dict, Callable
2
2
  from enum import Enum
3
- from uuid import UUID
4
3
 
5
4
  from pydantic import BaseModel
6
- from httpx import get as check
7
5
 
8
6
 
9
7
  from opentelemetry.trace import (
@@ -33,7 +31,6 @@ from agenta.sdk.tracing.conventions import Reference, is_valid_attribute_key
33
31
  from agenta.sdk.tracing.propagation import extract, inject
34
32
  from agenta.sdk.utils.cache import TTLLRUCache
35
33
 
36
- from agenta.sdk.context.tracing import tracing_context
37
34
 
38
35
  log = get_module_logger(__name__)
39
36
 
@@ -120,11 +117,6 @@ class Tracing(metaclass=Singleton):
120
117
  # TRACE PROCESSORS -- OTLP
121
118
  try:
122
119
  log.info("Agenta - OLTP URL: %s", self.otlp_url)
123
- # check(
124
- # self.otlp_url,
125
- # headers=self.headers,
126
- # timeout=1,
127
- # )
128
120
 
129
121
  _otlp = TraceProcessor(
130
122
  OTLPExporter(
agenta/sdk/utils/cache.py CHANGED
@@ -5,7 +5,7 @@ from collections import OrderedDict
5
5
  from threading import Lock
6
6
 
7
7
  CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512"))
8
- CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(60))) # 1 minute
8
+ CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes
9
9
 
10
10
 
11
11
  class TTLLRUCache:
@@ -38,7 +38,7 @@ class TTLLRUCache:
38
38
 
39
39
  return value
40
40
 
41
- def put(self, key, value):
41
+ def put(self, key, value, ttl: Optional[int] = None):
42
42
  with self.lock:
43
43
  try:
44
44
  # LRU update
@@ -50,4 +50,4 @@ class TTLLRUCache:
50
50
  self.cache.popitem(last=False)
51
51
 
52
52
  # Put
53
- self.cache[key] = (value, time() + self.ttl)
53
+ self.cache[key] = (value, time() + (ttl if ttl is not None else self.ttl))
File without changes
@@ -0,0 +1,32 @@
1
+ from json import dumps
2
+
3
+ from agenta.sdk.utils.logging import get_module_logger
4
+ from agenta.sdk.workflows.types import Data
5
+
6
+
7
+ log = get_module_logger(__name__)
8
+
9
+
10
async def exact_match_v1(
    *,
    parameters: "Data",
    inputs: "Data",
    outputs: "Data | str",
) -> "Data":
    """Check whether ``outputs`` exactly equals a reference value taken from ``inputs``.

    ``parameters["reference_key"]`` names the entry of ``inputs`` that holds
    the reference. Strings are compared as-is. Dicts and lists are compared by
    their canonical JSON (dict keys sorted), so dict key order does not matter
    but list order does. Every other case counts as no match: mismatched
    types, a missing reference, or ``parameters`` / ``inputs`` being None.

    Args:
        parameters: Evaluator parameters; may be None when the revision has none.
        inputs: Testcase inputs containing the reference value.
        outputs: Output under evaluation.

    Returns:
        ``{"success": bool}``.
    """
    success = False

    try:
        # parameters / inputs may arrive as None (empty parameters are passed as None).
        reference_key = (parameters or {}).get("reference_key", None)
        reference_outputs = (inputs or {}).get(reference_key, None)

        if isinstance(outputs, str) and isinstance(reference_outputs, str):
            success = outputs == reference_outputs
        elif (isinstance(outputs, dict) and isinstance(reference_outputs, dict)) or (
            isinstance(outputs, list) and isinstance(reference_outputs, list)
        ):
            # Canonical JSON makes dict key order irrelevant at every depth.
            success = dumps(outputs, sort_keys=True) == dumps(
                reference_outputs, sort_keys=True
            )

    except Exception:  # pylint: disable=broad-except
        log.error("Error in exact_match_v1", exc_info=True)

    return {"success": success}