agenta 0.32.0a1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. agenta/__init__.py +2 -0
  2. agenta/client/backend/__init__.py +39 -31
  3. agenta/client/backend/admin/__init__.py +1 -0
  4. agenta/client/backend/admin/client.py +576 -0
  5. agenta/client/backend/apps/client.py +450 -68
  6. agenta/client/backend/bases/client.py +10 -10
  7. agenta/client/backend/client.py +112 -122
  8. agenta/client/backend/containers/client.py +70 -28
  9. agenta/client/backend/core/http_client.py +3 -3
  10. agenta/client/backend/environments/client.py +8 -8
  11. agenta/client/backend/evaluations/client.py +46 -52
  12. agenta/client/backend/evaluators/client.py +32 -32
  13. agenta/client/backend/human_evaluations/__init__.py +1 -0
  14. agenta/client/backend/human_evaluations/client.py +1692 -0
  15. agenta/client/backend/observability/__init__.py +4 -0
  16. agenta/client/backend/observability/client.py +221 -744
  17. agenta/client/backend/testsets/client.py +38 -202
  18. agenta/client/backend/types/__init__.py +34 -28
  19. agenta/client/backend/types/account_response.py +24 -0
  20. agenta/client/backend/types/app_variant_revision.py +2 -1
  21. agenta/client/backend/types/{create_trace_response.py → delete_evaluation.py} +2 -3
  22. agenta/client/backend/types/{evaluation_scenario_score_update.py → legacy_scope_request.py} +2 -2
  23. agenta/client/backend/types/legacy_scopes_response.py +29 -0
  24. agenta/client/backend/types/{human_evaluation_update.py → legacy_user_request.py} +4 -4
  25. agenta/client/backend/types/{span_variant.py → legacy_user_response.py} +2 -4
  26. agenta/client/backend/types/organization_membership_request.py +25 -0
  27. agenta/client/backend/types/organization_request.py +23 -0
  28. agenta/client/backend/types/permission.py +4 -0
  29. agenta/client/backend/types/{llm_tokens.py → project_membership_request.py} +8 -5
  30. agenta/client/backend/types/project_request.py +26 -0
  31. agenta/client/backend/types/project_scope.py +29 -0
  32. agenta/client/backend/types/provider_kind.py +1 -1
  33. agenta/client/backend/types/reference.py +22 -0
  34. agenta/client/backend/types/role.py +15 -0
  35. agenta/client/backend/types/scopes_response_model.py +22 -0
  36. agenta/client/backend/types/score.py +1 -1
  37. agenta/client/backend/types/secret_response_dto.py +2 -2
  38. agenta/client/backend/types/user_request.py +22 -0
  39. agenta/client/backend/types/workspace_membership_request.py +26 -0
  40. agenta/client/backend/types/workspace_request.py +25 -0
  41. agenta/client/backend/variants/client.py +208 -42
  42. agenta/client/backend/vault/client.py +11 -9
  43. agenta/sdk/__init__.py +3 -0
  44. agenta/sdk/agenta_init.py +3 -1
  45. agenta/sdk/assets.py +4 -4
  46. agenta/sdk/decorators/routing.py +129 -23
  47. agenta/sdk/decorators/tracing.py +16 -4
  48. agenta/sdk/litellm/litellm.py +44 -8
  49. agenta/sdk/litellm/mockllm.py +2 -2
  50. agenta/sdk/litellm/mocks/__init__.py +9 -3
  51. agenta/sdk/managers/apps.py +64 -0
  52. agenta/sdk/managers/shared.py +2 -2
  53. agenta/sdk/middleware/auth.py +156 -53
  54. agenta/sdk/middleware/config.py +28 -16
  55. agenta/sdk/middleware/inline.py +1 -1
  56. agenta/sdk/middleware/mock.py +1 -1
  57. agenta/sdk/middleware/otel.py +1 -1
  58. agenta/sdk/middleware/vault.py +1 -1
  59. agenta/sdk/tracing/exporters.py +0 -1
  60. agenta/sdk/tracing/inline.py +26 -30
  61. agenta/sdk/types.py +12 -9
  62. {agenta-0.32.0a1.dist-info → agenta-0.33.0.dist-info}/METADATA +23 -20
  63. {agenta-0.32.0a1.dist-info → agenta-0.33.0.dist-info}/RECORD +69 -63
  64. agenta/client/backend/observability_v_1/__init__.py +0 -5
  65. agenta/client/backend/observability_v_1/client.py +0 -763
  66. agenta/client/backend/types/create_span.py +0 -45
  67. agenta/client/backend/types/human_evaluation_scenario_update.py +0 -30
  68. agenta/client/backend/types/new_human_evaluation.py +0 -27
  69. agenta/client/backend/types/outputs.py +0 -5
  70. agenta/client/backend/types/span.py +0 -42
  71. agenta/client/backend/types/span_detail.py +0 -44
  72. agenta/client/backend/types/span_status_code.py +0 -5
  73. agenta/client/backend/types/trace_detail.py +0 -44
  74. agenta/client/backend/types/with_pagination.py +0 -26
  75. /agenta/client/backend/{observability_v_1 → observability}/types/__init__.py +0 -0
  76. /agenta/client/backend/{observability_v_1 → observability}/types/format.py +0 -0
  77. /agenta/client/backend/{observability_v_1 → observability}/types/query_analytics_response.py +0 -0
  78. /agenta/client/backend/{observability_v_1 → observability}/types/query_traces_response.py +0 -0
  79. {agenta-0.32.0a1.dist-info → agenta-0.33.0.dist-info}/WHEEL +0 -0
  80. {agenta-0.32.0a1.dist-info → agenta-0.33.0.dist-info}/entry_points.txt +0 -0
@@ -11,16 +11,16 @@ from fastapi.responses import JSONResponse
11
11
  from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL
12
12
  from agenta.sdk.utils.constants import TRUTHY
13
13
  from agenta.sdk.utils.exceptions import display_exception
14
+ from agenta.sdk.utils.logging import log
14
15
 
15
16
  import agenta as ag
16
17
 
18
+ AGENTA_RUNTIME_PREFIX = getenv("AGENTA_RUNTIME_PREFIX", "")
17
19
 
18
- _SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "false").lower() in TRUTHY
19
- _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
20
- _UNAUTHORIZED_ALLOWED = (
21
- getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in TRUTHY
22
- )
23
- _ALWAYS_ALLOW_LIST = ["/health"]
20
+
21
+ _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "false").lower() in TRUTHY
22
+
23
+ _ALWAYS_ALLOW_LIST = [f"{AGENTA_RUNTIME_PREFIX}/health"]
24
24
 
25
25
  _cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL)
26
26
 
@@ -54,15 +54,49 @@ class AuthMiddleware(BaseHTTPMiddleware):
54
54
  super().__init__(app)
55
55
 
56
56
  self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
57
- self.resource_id = (
58
- ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
59
- if not _SHARED_SERVICE
60
- else None
61
- )
57
+ self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
62
58
 
63
59
  async def dispatch(self, request: Request, call_next: Callable):
60
+ # Extract request details
61
+ host = request.client.host if request.client else "unknown"
62
+ path = request.url.path
63
+ query = dict(request.query_params)
64
+ headers = dict(request.headers)
65
+
66
+ import logging
67
+ import json
68
+
69
+ # Log the request details
70
+ logging.error(
71
+ json.dumps(
72
+ {
73
+ "host": host,
74
+ "method": request.method,
75
+ "path": path,
76
+ "query_params": query,
77
+ "headers": headers,
78
+ },
79
+ indent=2,
80
+ ensure_ascii=False,
81
+ )
82
+ )
83
+
84
+ print(
85
+ json.dumps(
86
+ {
87
+ "host": host,
88
+ "method": request.method,
89
+ "path": path,
90
+ "query_params": query,
91
+ "headers": headers,
92
+ },
93
+ indent=2,
94
+ ensure_ascii=False,
95
+ )
96
+ )
97
+
64
98
  try:
65
- if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST:
99
+ if request.url.path in _ALWAYS_ALLOW_LIST:
66
100
  request.state.auth = {}
67
101
 
68
102
  else:
@@ -98,7 +132,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
98
132
 
99
133
  cookies = {"sAccessToken": access_token} if access_token else None
100
134
 
101
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
135
+ if not headers and not cookies:
136
+ log.debug("No auth header nor auth cookie found in the request")
137
+
138
+ baggage = request.state.otel["baggage"]
102
139
 
103
140
  project_id = (
104
141
  # CLEANEST
@@ -107,6 +144,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
107
144
  or request.query_params.get("project_id")
108
145
  )
109
146
 
147
+ if not project_id:
148
+ log.debug("No project ID found in request")
149
+
110
150
  params = {"action": "run_service", "resource_type": "service"}
111
151
 
112
152
  if self.resource_id:
@@ -128,53 +168,116 @@ class AuthMiddleware(BaseHTTPMiddleware):
128
168
  credentials = _cache.get(_hash)
129
169
 
130
170
  if credentials:
171
+ log.debug("Using cached credentials")
131
172
  return credentials
132
173
 
133
- async with httpx.AsyncClient() as client:
134
- response = await client.get(
135
- f"{self.host}/api/permissions/verify",
136
- headers=headers,
137
- cookies=cookies,
138
- params=params,
139
- )
140
-
141
- if response.status_code == 401:
142
- raise DenyException(
143
- status_code=401,
144
- content="Invalid credentials",
145
- )
146
- elif response.status_code == 403:
147
- raise DenyException(
148
- status_code=403,
149
- content="Service execution not allowed.",
150
- )
151
- elif response.status_code != 200:
152
- raise DenyException(
153
- status_code=400,
154
- content="Auth: Unexpected Error.",
155
- )
156
-
157
- auth = response.json()
174
+ try:
175
+ async with httpx.AsyncClient() as client:
176
+ try:
177
+ response = await client.get(
178
+ f"{self.host}/api/permissions/verify",
179
+ headers=headers,
180
+ cookies=cookies,
181
+ params=params,
182
+ timeout=30.0,
183
+ )
184
+ except httpx.TimeoutException as exc:
185
+ log.debug(f"Timeout error while verify credentials: {exc}")
186
+ raise DenyException(
187
+ status_code=504,
188
+ content="Could not verify credentials: connection to {self.host} timed out. Please check your network connection.",
189
+ ) from exc
190
+ except httpx.ConnectError as exc:
191
+ log.debug(f"Connection error while verify credentials: {exc}")
192
+ raise DenyException(
193
+ status_code=503,
194
+ content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
195
+ ) from exc
196
+ except httpx.NetworkError as exc:
197
+ log.debug(f"Network error while verify credentials: {exc}")
198
+ raise DenyException(
199
+ status_code=503,
200
+ content="Could not verify credentials: connection to {self.host} failed. Please check your network connection.",
201
+ ) from exc
202
+ except httpx.HTTPError as exc:
203
+ log.debug(f"HTTP error while verify credentials: {exc}")
204
+ raise DenyException(
205
+ status_code=502,
206
+ content=f"Could not verify credentials: connection to {self.host} failed. Please check if agenta is available.",
207
+ ) from exc
208
+
209
+ if response.status_code == 401:
210
+ log.debug("Agenta returned 401 - Invalid credentials")
211
+ raise DenyException(
212
+ status_code=401,
213
+ content="Invalid credentials. Please check your credentials or login again.",
214
+ )
215
+ elif response.status_code == 403:
216
+ log.debug("Agenta returned 403 - Permission denied")
217
+ raise DenyException(
218
+ status_code=403,
219
+ content="Permission denied. Please check your permissions or contact your administrator.",
220
+ )
221
+ elif response.status_code != 200:
222
+ log.debug(
223
+ f"Agenta returned {response.status_code} - Unexpected status code"
224
+ )
225
+ raise DenyException(
226
+ status_code=500,
227
+ content=f"Could no verify credentials: {self.host} returned unexpected status code {response.status_code}. Please try again later or contact support if the issue persists.",
228
+ )
229
+
230
+ try:
231
+ auth = response.json()
232
+ except ValueError as exc:
233
+ log.debug(f"Agenta returned invalid JSON response: {exc}")
234
+ raise DenyException(
235
+ status_code=500,
236
+ content=f"Could no verify credentials: {self.host} returned unexpected invalid JSON response. Please try again later or contact support if the issue persists.",
237
+ ) from exc
238
+
239
+ if not isinstance(auth, dict):
240
+ log.debug(
241
+ f"Agenta returned invalid response format: {type(auth)}"
242
+ )
243
+ raise DenyException(
244
+ status_code=500,
245
+ content=f"Could no verify credentials: {self.host} returned unexpected invalid response format. Please try again later or contact support if the issue persists.",
246
+ )
247
+
248
+ effect = auth.get("effect")
249
+ if effect != "allow":
250
+ log.debug("Access denied by Agenta - effect: {effect}")
251
+ raise DenyException(
252
+ status_code=403,
253
+ content="Permission denied. Please check your permissions or contact your administrator.",
254
+ )
255
+
256
+ credentials = auth.get("credentials")
257
+
258
+ if not credentials:
259
+ log.debug("No credentials found in the response")
260
+
261
+ _cache.put(_hash, credentials)
158
262
 
159
- if auth.get("effect") != "allow":
160
- raise DenyException(
161
- status_code=403,
162
- content="Service execution not allowed.",
163
- )
164
-
165
- credentials = auth.get("credentials")
166
-
167
- _cache.put(_hash, credentials)
263
+ return credentials
168
264
 
169
- return credentials
265
+ except DenyException as deny:
266
+ raise deny
267
+ except Exception as exc: # pylint: disable=bare-except
268
+ log.debug(
269
+ f"Unexpected error while verifying credentials (remote): {exc}"
270
+ )
271
+ raise DenyException(
272
+ status_code=500,
273
+ content=f"Could no verify credentials: unexpected error - {str(exc)}. Please try again later or contact support if the issue persists.",
274
+ ) from exc
170
275
 
171
276
  except DenyException as deny:
172
277
  raise deny
173
-
174
- except Exception as exc: # pylint: disable=bare-except
175
- display_exception("Auth Middleware Exception (suppressed)")
176
-
278
+ except Exception as exc:
279
+ log.debug(f"Unexpected error while verifying credentials (local): {exc}")
177
280
  raise DenyException(
178
281
  status_code=500,
179
- content="Auth: Unexpected Error.",
282
+ content=f"Could no verify credentials: unexpected error - {str(exc)}. Please try again later or contact support if the issue persists.",
180
283
  ) from exc
@@ -17,7 +17,7 @@ from agenta.sdk.utils.exceptions import suppress
17
17
  import agenta as ag
18
18
 
19
19
 
20
- _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
20
+ _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "false").lower() in TRUTHY
21
21
 
22
22
  _cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL)
23
23
 
@@ -92,7 +92,7 @@ class ConfigMiddleware(BaseHTTPMiddleware):
92
92
 
93
93
  return parameters, references
94
94
 
95
- config = None
95
+ config = {}
96
96
  async with httpx.AsyncClient() as client:
97
97
  response = await client.post(
98
98
  f"{self.host}/api/variants/configs/fetch",
@@ -100,21 +100,25 @@ class ConfigMiddleware(BaseHTTPMiddleware):
100
100
  json=refs,
101
101
  )
102
102
 
103
- if response.status_code != 200:
104
- return None, None
105
-
106
- config = response.json()
103
+ if response.status_code == 200:
104
+ config = response.json()
107
105
 
108
106
  if not config:
109
- _cache.put(_hash, {"parameters": None, "references": None})
107
+ config["application_ref"] = refs[
108
+ "application_ref"
109
+ ] # by default, application_ref will always have an id
110
+ parameters = None
111
+ else:
112
+ parameters = config.get("params")
110
113
 
111
- return None, None
114
+ references = {}
112
115
 
113
- parameters = config.get("params")
116
+ ref_keys = ["application_ref"]
114
117
 
115
- references = {}
118
+ if config:
119
+ ref_keys.extend(["variant_ref", "environment_ref"])
116
120
 
117
- for ref_key in ["application_ref", "variant_ref", "environment_ref"]:
121
+ for ref_key in ref_keys:
118
122
  refs = config.get(ref_key)
119
123
  if refs:
120
124
  ref_prefix = ref_key.split("_", maxsplit=1)[0]
@@ -123,7 +127,7 @@ class ConfigMiddleware(BaseHTTPMiddleware):
123
127
  ref_part = refs.get(ref_part_key)
124
128
 
125
129
  if ref_part:
126
- references[ref_prefix + "." + ref_part_key] = ref_part
130
+ references[ref_prefix + "." + ref_part_key] = str(ref_part)
127
131
 
128
132
  _cache.put(_hash, {"parameters": parameters, "references": references})
129
133
 
@@ -133,7 +137,7 @@ class ConfigMiddleware(BaseHTTPMiddleware):
133
137
  self,
134
138
  request: Request,
135
139
  ) -> Optional[Reference]:
136
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
140
+ baggage = request.state.otel["baggage"]
137
141
 
138
142
  body = {}
139
143
  try:
@@ -160,19 +164,27 @@ class ConfigMiddleware(BaseHTTPMiddleware):
160
164
  or body.get("app")
161
165
  )
162
166
 
163
- if not any([application_id, application_slug]):
167
+ application_version = (
168
+ # CLEANEST
169
+ baggage.get("application_version")
170
+ # ALTERNATIVE
171
+ or request.query_params.get("application_version")
172
+ )
173
+
174
+ if not any([application_id, application_slug, application_version]):
164
175
  return None
165
176
 
166
177
  return Reference(
167
178
  id=application_id,
168
179
  slug=application_slug,
180
+ version=application_version,
169
181
  )
170
182
 
171
183
  async def _parse_variant_ref(
172
184
  self,
173
185
  request: Request,
174
186
  ) -> Optional[Reference]:
175
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
187
+ baggage = request.state.otel["baggage"]
176
188
 
177
189
  body = {}
178
190
  try:
@@ -215,7 +227,7 @@ class ConfigMiddleware(BaseHTTPMiddleware):
215
227
  self,
216
228
  request: Request,
217
229
  ) -> Optional[Reference]:
218
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
230
+ baggage = request.state.otel["baggage"]
219
231
 
220
232
  body = {}
221
233
  try:
@@ -21,7 +21,7 @@ class InlineMiddleware(BaseHTTPMiddleware):
21
21
  request.state.inline = False
22
22
 
23
23
  with suppress():
24
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
24
+ baggage = request.state.otel["baggage"]
25
25
 
26
26
  inline = (
27
27
  str(
@@ -19,7 +19,7 @@ class MockMiddleware(BaseHTTPMiddleware):
19
19
  request.state.mock = None
20
20
 
21
21
  with suppress():
22
- baggage = request.state.otel.get("baggage") if request.state.otel else {}
22
+ baggage = request.state.otel["baggage"]
23
23
 
24
24
  mock = (
25
25
  # CLEANEST
@@ -13,7 +13,7 @@ class OTelMiddleware(BaseHTTPMiddleware):
13
13
  super().__init__(app)
14
14
 
15
15
  async def dispatch(self, request: Request, call_next: Callable):
16
- request.state.otel = {}
16
+ request.state.otel = {"baggage": {}}
17
17
 
18
18
  with suppress():
19
19
  baggage = await self._get_baggage(request)
@@ -24,7 +24,7 @@ for arg in ProviderKind.__args__: # type: ignore
24
24
  if hasattr(arg, "__args__"):
25
25
  _PROVIDER_KINDS.extend(arg.__args__)
26
26
 
27
- _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY
27
+ _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "false").lower() in TRUTHY
28
28
 
29
29
  _cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL)
30
30
 
@@ -48,7 +48,6 @@ class InlineTraceExporter(SpanExporter):
48
48
  trace_id: int,
49
49
  ) -> bool:
50
50
  is_ready = trace_id in self._registry
51
-
52
51
  return is_ready
53
52
 
54
53
  def fetch(
@@ -701,15 +701,6 @@ def _parse_from_semconv(
701
701
  def _parse_from_links(
702
702
  otel_span_dto: OTelSpanDTO,
703
703
  ) -> dict:
704
- # TESTING
705
- otel_span_dto.links = [
706
- OTelLinkDTO(
707
- context=otel_span_dto.context,
708
- attributes={"ag.type.link": "testcase"},
709
- )
710
- ]
711
- # -------
712
-
713
704
  # LINKS
714
705
  links = None
715
706
  otel_links = None
@@ -926,8 +917,9 @@ def parse_to_agenta_span_dto(
926
917
  if span_dto.refs:
927
918
  span_dto.refs = _unmarshal_attributes(span_dto.refs)
928
919
 
929
- for link in span_dto.links:
930
- link.tree_id = None
920
+ if isinstance(span_dto.links, list):
921
+ for link in span_dto.links:
922
+ link.tree_id = None
931
923
 
932
924
  if span_dto.nodes:
933
925
  for v in span_dto.nodes.values():
@@ -1030,6 +1022,24 @@ def _parse_readable_spans(
1030
1022
  otel_span_dtos = list()
1031
1023
 
1032
1024
  for span in spans:
1025
+ otel_events = [
1026
+ OTelEventDTO(
1027
+ name=event.name,
1028
+ timestamp=_timestamp_ns_to_datetime(event.timestamp),
1029
+ attributes=event.attributes,
1030
+ )
1031
+ for event in span.events
1032
+ ]
1033
+ otel_links = [
1034
+ OTelLinkDTO(
1035
+ context=OTelContextDTO(
1036
+ trace_id=_int_to_hex(link.context.trace_id, 128),
1037
+ span_id=_int_to_hex(link.context.span_id, 64),
1038
+ ),
1039
+ attributes=link.attributes,
1040
+ )
1041
+ for link in span.links
1042
+ ]
1033
1043
  otel_span_dto = OTelSpanDTO(
1034
1044
  context=OTelContextDTO(
1035
1045
  trace_id=_int_to_hex(span.get_span_context().trace_id, 128),
@@ -1045,14 +1055,7 @@ def _parse_readable_spans(
1045
1055
  status_code=OTelStatusCode("STATUS_CODE_" + span.status.status_code.name),
1046
1056
  status_message=span.status.description,
1047
1057
  attributes=span.attributes,
1048
- events=[
1049
- OTelEventDTO(
1050
- name=event.name,
1051
- timestamp=_timestamp_ns_to_datetime(event.timestamp),
1052
- attributes=event.attributes,
1053
- )
1054
- for event in span.events
1055
- ],
1058
+ events=otel_events if len(otel_events) > 0 else None,
1056
1059
  parent=(
1057
1060
  OTelContextDTO(
1058
1061
  trace_id=_int_to_hex(span.parent.trace_id, 128),
@@ -1061,16 +1064,7 @@ def _parse_readable_spans(
1061
1064
  if span.parent
1062
1065
  else None
1063
1066
  ),
1064
- links=[
1065
- OTelLinkDTO(
1066
- context=OTelContextDTO(
1067
- trace_id=_int_to_hex(link.context.trace_id, 128),
1068
- span_id=_int_to_hex(link.context.span_id, 64),
1069
- ),
1070
- attributes=link.attributes,
1071
- )
1072
- for link in span.links
1073
- ],
1067
+ links=otel_links if len(otel_links) > 0 else None,
1074
1068
  )
1075
1069
 
1076
1070
  otel_span_dtos.append(otel_span_dto)
@@ -1121,7 +1115,9 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]):
1121
1115
  and span.meta
1122
1116
  and span.metrics
1123
1117
  ):
1124
- model = span.meta.get("response.model")
1118
+ model = span.meta.get("response.model") or span.meta.get(
1119
+ "configuration.model"
1120
+ )
1125
1121
  prompt_tokens = span.metrics.get("unit.tokens.prompt", 0.0)
1126
1122
  completion_tokens = span.metrics.get("unit.tokens.completion", 0.0)
1127
1123
 
agenta/sdk/types.py CHANGED
@@ -7,7 +7,7 @@ from pydantic import ConfigDict, BaseModel, HttpUrl
7
7
  from agenta.client.backend.types.agenta_node_dto import AgentaNodeDto
8
8
  from agenta.client.backend.types.agenta_nodes_response import AgentaNodesResponse
9
9
  from typing import Annotated, List, Union, Optional, Dict, Literal, Any
10
- from pydantic import BaseModel, Field, root_validator
10
+ from pydantic import BaseModel, Field, model_validator
11
11
  from agenta.sdk.assets import supported_llm_models
12
12
 
13
13
 
@@ -36,12 +36,14 @@ class LLMTokenUsage(BaseModel):
36
36
 
37
37
 
38
38
  class BaseResponse(BaseModel):
39
- version: Optional[str] = "3.1"
39
+ version: Optional[str] = "3.0"
40
40
  data: Optional[Union[str, Dict[str, Any]]] = None
41
41
  content_type: Optional[str] = "string"
42
42
  tree: Optional[AgentaNodesResponse] = None
43
43
  tree_id: Optional[str] = None
44
44
 
45
+ model_config = ConfigDict(use_enum_values=True, exclude_none=True)
46
+
45
47
 
46
48
  class DictInput(dict):
47
49
  def __new__(cls, default_keys: Optional[List[str]] = None):
@@ -327,30 +329,31 @@ class ModelConfig(BaseModel):
327
329
  )
328
330
 
329
331
  temperature: Optional[float] = Field(
330
- default=1,
332
+ default=None,
331
333
  ge=0.0,
332
334
  le=2.0,
333
335
  description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic",
334
336
  )
335
337
  max_tokens: Optional[int] = Field(
336
- default=-1,
338
+ default=None,
337
339
  ge=0,
340
+ le=4000,
338
341
  description="The maximum number of tokens that can be generated in the chat completion",
339
342
  )
340
343
  top_p: Optional[float] = Field(
341
- default=0.5,
344
+ default=None,
342
345
  ge=0.0,
343
346
  le=1.0,
344
347
  description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass",
345
348
  )
346
349
  frequency_penalty: Optional[float] = Field(
347
- default=0,
350
+ default=None,
348
351
  ge=-2.0,
349
352
  le=2.0,
350
353
  description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far",
351
354
  )
352
355
  presence_penalty: Optional[float] = Field(
353
- default=0,
356
+ default=None,
354
357
  ge=-2.0,
355
358
  le=2.0,
356
359
  description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far",
@@ -405,7 +408,7 @@ class PromptTemplate(BaseModel):
405
408
  system_prompt: Optional[str] = None
406
409
  user_prompt: Optional[str] = None
407
410
  template_format: Literal["fstring", "jinja2", "curly"] = Field(
408
- default="fstring",
411
+ default="curly",
409
412
  description="Format type for template variables: fstring {var}, jinja2 {{ var }}, or curly {{var}}",
410
413
  )
411
414
  input_keys: Optional[List[str]] = Field(
@@ -425,7 +428,7 @@ class PromptTemplate(BaseModel):
425
428
  }
426
429
  }
427
430
 
428
- @root_validator(pre=True)
431
+ @model_validator(mode="before")
429
432
  def init_messages(cls, values):
430
433
  if "messages" not in values:
431
434
  messages = []