agentkeeper-runtime-sdk 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,314 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import Any, AsyncIterator, Iterator, Mapping, Optional
5
+
6
+ from .client import (
7
+ AgentKeeperRuntimeClient,
8
+ compact,
9
+ create_agentkeeper_runtime_client,
10
+ get_value,
11
+ number_value,
12
+ safe_error_name,
13
+ string_value,
14
+ telemetry_metadata,
15
+ )
16
+
17
+
18
def _client_from(value: AgentKeeperRuntimeClient | Mapping[str, Any]) -> AgentKeeperRuntimeClient:
    """Return a runtime client for *value*.

    An ``AgentKeeperRuntimeClient`` instance is passed through unchanged; any other
    mapping is unpacked as keyword arguments to ``create_agentkeeper_runtime_client``.
    """
    if not isinstance(value, AgentKeeperRuntimeClient):
        return create_agentkeeper_runtime_client(**dict(value))
    return value
22
+
23
+
24
def _base_event(options: Mapping[str, Any]) -> dict[str, Any]:
    """Build the fields shared by every Azure OpenAI event.

    ``runtime_service``/``runtime_environment`` are read from *options* (snake_case
    first, camelCase as fallback). Entries of ``options["event"]``, if present,
    are merged last and so override the defaults.
    """
    service = options.get("runtime_service") or options.get("runtimeService")
    environment = options.get("runtime_environment") or options.get("runtimeEnvironment")
    event: dict[str, Any] = {
        "runtime_integration": "azure_openai",
        "runtime_service": service,
        "runtime_environment": environment,
    }
    event.update(options.get("event") or {})
    return event
31
+
32
+
33
def _request_from_args(args: tuple[Any, ...], kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Flatten a call's arguments into a single request dict.

    If the first positional argument is a mapping, its entries form the base;
    keyword arguments are layered on top and win on conflicts. Other positional
    arguments are ignored. The result is always a fresh dict.
    """
    merged: dict[str, Any] = {}
    if args and isinstance(args[0], Mapping):
        merged.update(args[0])
    merged.update(kwargs)
    return merged
37
+
38
+
39
def _tool_names(request: Mapping[str, Any]) -> list[str]:
    """Collect the names of the tools declared in *request*.

    Looks at ``request["tools"]`` only when it is a list, and only at its first
    50 entries. Each entry's name is taken from ``tool.function.name`` or, failing
    that, ``tool.name``; entries without a truthy name are skipped.
    """
    declared = request.get("tools")
    if not isinstance(declared, list):
        return []
    candidates = (
        string_value(get_value(get_value(entry, "function"), "name"))
        or string_value(get_value(entry, "name"))
        for entry in declared[:50]
    )
    return [label for label in candidates if label]
50
+
51
+
52
def _model_name(request: Mapping[str, Any], response: Any = None) -> Optional[str]:
    """Return the model identifier for a call.

    The response's ``model`` field is preferred; otherwise the request keys
    ``model``, ``deployment``, ``deployment_name`` and ``deploymentName`` are tried
    in that order. Lookups stop at the first truthy value; if none is truthy, the
    last lookup's result is returned.
    """
    result = string_value(get_value(response, "model"))
    for key in ("model", "deployment", "deployment_name", "deploymentName"):
        if result:
            break
        result = string_value(request.get(key))
    return result
60
+
61
+
62
def _input_type(value: Any) -> Optional[str]:
    """Describe the kind of a request's ``input`` value with a JSON-style label.

    Returns ``None`` for ``None``; "array", "object", "string", "boolean" or
    "number" for the matching Python types; otherwise the value's type name.
    """
    if value is None:
        return None
    # Order matters: bool must be checked before the numeric types because
    # bool is a subclass of int.
    labels = (
        (list, "array"),
        (Mapping, "object"),
        (str, "string"),
        (bool, "boolean"),
        ((int, float), "number"),
    )
    for kind, label in labels:
        if isinstance(value, kind):
            return label
    return type(value).__name__
76
+
77
+
78
def _request_metadata(request: Mapping[str, Any], operation: str, options: Mapping[str, Any]) -> dict[str, Any]:
    """Summarise an Azure OpenAI request as event metadata.

    Only counts, names, types and limits are recorded; message contents are not
    copied.

    Args:
        request: Merged call arguments (see ``_request_from_args``).
        operation: Dotted operation path, e.g. ``"chat.completions.create"``.
        options: Wrapper options; each id is read from its snake_case key first,
            then from the camelCase spelling.

    Returns:
        The ``compact``-ed dict of telemetry ids plus request-shape fields.
    """

    def option(snake_key: str, camel_key: str) -> Any:
        return options.get(snake_key) or options.get(camel_key)

    ids = telemetry_metadata(
        runtime_service=option("runtime_service", "runtimeService"),
        runtime_environment=option("runtime_environment", "runtimeEnvironment"),
        run_id=option("run_id", "runId"),
        trace_id=option("trace_id", "traceId"),
        span_id=option("span_id", "spanId"),
        user_id=option("user_id", "userId"),
        convo_id=option("convo_id", "convoId"),
    )
    raw_tools = request.get("tools")
    # Count every declared tool. _tool_names() stops after 50 entries and skips
    # tools that have no name (e.g. built-in tools declared only by "type"),
    # so its length is not a reliable tool count.
    tool_count = len(raw_tools) if isinstance(raw_tools, list) else 0
    tool_names = _tool_names(request)
    messages = request.get("messages")
    response_format = request.get("response_format")
    return compact({
        **ids,
        "operation": operation,
        "model": _model_name(request),
        "api_version": string_value(request.get("api_version")) or string_value(request.get("apiVersion")),
        "message_count": len(messages) if isinstance(messages, list) else None,
        "input_type": _input_type(request.get("input")),
        "tool_count": tool_count or None,
        "tool_names": tool_names or None,
        "response_format_type": string_value(get_value(response_format, "type")),
        "max_tokens": (
            number_value(request.get("max_tokens"))
            or number_value(request.get("maxTokens"))
            # Newer limit parameters: Chat Completions `max_completion_tokens`
            # and Responses API `max_output_tokens`.
            or number_value(request.get("max_completion_tokens"))
            or number_value(request.get("max_output_tokens"))
        ),
        "stream": request.get("stream") is True,
    })
103
+
104
+
105
def _usage_fields(response: Any) -> dict[str, Any]:
    """Extract token counts from a response's ``usage`` block.

    Accepts both Chat Completions names (``prompt_tokens``/``completion_tokens``)
    and Responses API names (``input_tokens``/``output_tokens``), each in snake_case
    or camelCase. Returns ``token_input``/``token_output`` passed through ``compact``.
    """
    usage = get_value(response, "usage") or {}

    def first_count(*keys: str) -> Any:
        # Mirrors an `a or b or c` chain: stop at the first truthy count,
        # otherwise keep the last lookup's result.
        count = None
        for key in keys:
            count = number_value(get_value(usage, key))
            if count:
                break
        return count

    return compact({
        "token_input": first_count("prompt_tokens", "promptTokens", "input_tokens", "inputTokens"),
        "token_output": first_count("completion_tokens", "completionTokens", "output_tokens", "outputTokens"),
    })
121
+
122
+
123
def _response_metadata(response: Any, metadata: Mapping[str, Any]) -> dict[str, Any]:
    """Extend request *metadata* with a summary of *response*.

    Adds the response id/object, the number of ``choices`` and ``output`` items,
    the distinct finish reasons seen in the first 50 choices (sorted), and the
    ``event_count`` recorded for streamed responses. The result goes through ``compact``.
    """
    raw_choices = get_value(response, "choices")
    choice_list = raw_choices if isinstance(raw_choices, list) else []
    distinct_reasons = {
        reason
        for reason in (
            string_value(get_value(choice, "finish_reason")) or string_value(get_value(choice, "finishReason"))
            for choice in choice_list[:50]
        )
        if reason
    }
    output_items = get_value(response, "output")
    summary = dict(metadata)
    summary.update({
        "response_id": string_value(get_value(response, "id")),
        "response_object": string_value(get_value(response, "object")),
        "choice_count": len(choice_list) or None,
        "output_item_count": len(output_items) if isinstance(output_items, list) else None,
        "finish_reasons": sorted(distinct_reasons) or None,
        "stream_event_count": number_value(get_value(response, "event_count")),
    })
    return compact(summary)
142
+
143
+
144
class _StreamProxy:
    """Wrap a streaming response so its completion is reported once.

    Iterating the proxy (sync or async) yields the underlying events unchanged
    while counting them. When iteration finishes or raises, ``track_once`` is
    called with ``{"model": model, "event_count": n}`` and the exception (if any),
    at most once per proxy. Any other attribute access is forwarded to the
    wrapped stream.
    """

    def __init__(self, stream: Any, track_once: Any, model: Optional[str]) -> None:
        self._stream = stream  # what __iter__/__aiter__ consume
        # The context manager that owns cleanup. __enter__ may replace
        # self._stream with the manager's return value, but __exit__ must still
        # be delivered to this original object.
        self._context = stream
        self._track_once = track_once
        self._model = model
        self._tracked = False

    def _track(self, response: Any, error: Optional[BaseException] = None) -> None:
        """Forward to ``track_once`` the first time only."""
        if self._tracked:
            return
        self._tracked = True
        self._track_once(response, error)

    def __getattr__(self, name: str) -> Any:
        return getattr(self._stream, name)

    def __iter__(self) -> Iterator[Any]:
        event_count = 0
        try:
            for event in self._stream:
                event_count += 1
                yield event
        except BaseException as error:
            # NOTE(review): a consumer that breaks out early closes this
            # generator with GeneratorExit, which is reported here as an error.
            self._track({"model": self._model, "event_count": event_count}, error)
            raise
        self._track({"model": self._model, "event_count": event_count})

    async def __aiter__(self) -> AsyncIterator[Any]:
        event_count = 0
        try:
            async for event in self._stream:
                event_count += 1
                yield event
        except BaseException as error:
            self._track({"model": self._model, "event_count": event_count}, error)
            raise
        self._track({"model": self._model, "event_count": event_count})

    def __enter__(self) -> Any:
        entered = self._context.__enter__()
        # Iterate whatever the manager hands back; a manager whose __enter__
        # returns None keeps the original stream.
        if entered is not None:
            self._stream = entered
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Any:
        return self._context.__exit__(exc_type, exc, tb)

    async def __aenter__(self) -> Any:
        entered = await self._context.__aenter__()
        if entered is not None:
            self._stream = entered
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> Any:
        return await self._context.__aexit__(exc_type, exc, tb)
200
+
201
+
202
def _is_stream_like(result: Any, request: Mapping[str, Any]) -> bool:
    """Return True when streaming was requested and *result* is (async) iterable.

    Streaming counts as requested only when ``request["stream"]`` is literally
    ``True``; iterability is detected by the presence of ``__aiter__`` or ``__iter__``.
    """
    wants_stream = request.get("stream") is True
    return wants_stream and any(hasattr(result, attr) for attr in ("__aiter__", "__iter__"))
206
+
207
+
208
class _AzureOpenAIResourceProxy:
    """Forwarding proxy that reports Azure OpenAI model calls to AgentKeeper.

    Attribute access is delegated to the wrapped client or sub-resource.
    Sub-resources on the path to a tracked operation are returned wrapped in
    another proxy, and the tracked ``create`` methods are returned wrapped so
    that each call emits a start event and a completion event through
    ``ak.track``. Every other attribute is returned unchanged.
    """

    # Dotted paths, relative to the wrapped client, whose calls are reported.
    _TRACKED_OPERATIONS = frozenset({"chat.completions.create", "responses.create", "embeddings.create"})
    # Proper prefixes of the tracked paths: the only attributes that need a proxy.
    # Wrapping anything else breaks it; for example a bound method such as
    # ``models.list`` would come back as a non-callable proxy object.
    _PROXIED_PATHS = frozenset({"chat", "chat.completions", "responses", "embeddings"})

    def __init__(
        self,
        resource: Any,
        ak: AgentKeeperRuntimeClient,
        options: Mapping[str, Any],
        path: tuple[str, ...] = (),
    ) -> None:
        self._resource = resource  # client or sub-resource being proxied
        self._ak = ak  # receives events via .track(...)
        self._options = dict(options)
        self._path = path  # attribute names walked from the root client to reach resource

    def __getattr__(self, name: str) -> Any:
        value = getattr(self._resource, name)
        operation = ".".join((*self._path, name))
        if operation in self._TRACKED_OPERATIONS and callable(value):
            return self._wrap_method(operation, value)
        if operation in self._PROXIED_PATHS and value is not None:
            return _AzureOpenAIResourceProxy(value, self._ak, self._options, (*self._path, name))
        return value

    def _track_start(self, request: Mapping[str, Any], operation: str, metadata: Mapping[str, Any]) -> None:
        """Emit the "observed" event sent before the underlying call runs."""
        self._ak.track({
            **_base_event(self._options),
            "event_kind": "model_call",
            "capability": "model_only",
            "model_provider": "azure_openai",
            "model_name": _model_name(request),
            "run_id": metadata.get("run_id"),
            "trace_id": metadata.get("trace_id"),
            "span_id": metadata.get("span_id"),
            "verdict": "observed",
            "evidence_summary": f"Azure OpenAI {operation} observed",
            "metadata": metadata,
        })

    def _track_completion(
        self,
        request: Mapping[str, Any],
        operation: str,
        metadata: Mapping[str, Any],
        response: Any,
        error: Optional[BaseException] = None,
    ) -> None:
        """Emit the completion event: "passed"/low on success, "observed"/medium on error."""
        self._ak.track({
            **_base_event(self._options),
            "event_kind": "model_call",
            "capability": "model_only",
            "model_provider": "azure_openai",
            "model_name": _model_name(request, response),
            **_usage_fields(response),
            "run_id": metadata.get("run_id"),
            "trace_id": metadata.get("trace_id"),
            "span_id": metadata.get("span_id"),
            "verdict": "observed" if error else "passed",
            "severity": "medium" if error else "low",
            "evidence_summary": (
                f"Azure OpenAI {operation} errored: {safe_error_name(error)}"
                if error
                else f"Azure OpenAI {operation} completed"
            ),
            "metadata": _response_metadata(response, metadata),
        })

    def _wrap_method(self, operation: str, method: Any) -> Any:
        """Return a wrapper that reports a start event and exactly one completion event.

        Awaitable results are awaited inside a coroutine that the wrapper returns.
        Streaming results (``stream=True`` with an iterable result) are wrapped in
        ``_StreamProxy``, so completion is reported once the stream is consumed.
        Exceptions are reported and then re-raised unchanged.
        """
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            request = _request_from_args(args, kwargs)
            metadata = _request_metadata(request, operation, self._options)
            self._track_start(request, operation, metadata)
            tracked = False

            def track_once(response: Any, error: Optional[BaseException] = None) -> None:
                # Guards against double reporting across the sync, async and stream paths.
                nonlocal tracked
                if tracked:
                    return
                tracked = True
                self._track_completion(request, operation, metadata, response, error)

            def wrap_or_track(result: Any) -> Any:
                if _is_stream_like(result, request):
                    return _StreamProxy(result, track_once, _model_name(request))
                track_once(result)
                return result

            try:
                result = method(*args, **kwargs)
            except BaseException as error:
                track_once({}, error)
                raise

            if inspect.isawaitable(result):
                async def wait() -> Any:
                    try:
                        resolved = await result
                    except BaseException as error:
                        track_once({}, error)
                        raise
                    return wrap_or_track(resolved)

                return wait()

            return wrap_or_track(result)

        return wrapped
307
+
308
+
309
def wrap_azure_openai_client(
    azure_openai_client: Any,
    client_or_options: AgentKeeperRuntimeClient | Mapping[str, Any],
    **options: Any,
) -> Any:
    """Wrap an Azure OpenAI client so its model calls are reported to AgentKeeper.

    Args:
        azure_openai_client: The client object to proxy.
        client_or_options: An existing ``AgentKeeperRuntimeClient``, or a mapping of
            keyword arguments for ``create_agentkeeper_runtime_client``.
        **options: Integration options read by the proxy, e.g. ``runtime_service``
            / ``runtimeService``, ``run_id``, ``trace_id``, ``event``.

    Returns:
        A proxy that forwards attribute access to *azure_openai_client*.
    """
    runtime_client = _client_from(client_or_options)
    return _AzureOpenAIResourceProxy(azure_openai_client, runtime_client, options)
@@ -0,0 +1,187 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import json
5
+ from typing import Any, Mapping, Optional
6
+
7
+ from .client import (
8
+ AgentKeeperRuntimeClient,
9
+ compact,
10
+ create_agentkeeper_runtime_client,
11
+ get_value,
12
+ number_value,
13
+ object_keys,
14
+ safe_error_name,
15
+ string_value,
16
+ telemetry_metadata,
17
+ )
18
+
19
+
20
def _client_from(value: AgentKeeperRuntimeClient | Mapping[str, Any]) -> AgentKeeperRuntimeClient:
    """Resolve *value* to a runtime client.

    Existing ``AgentKeeperRuntimeClient`` instances are reused as-is; a mapping is
    expanded into keyword arguments for ``create_agentkeeper_runtime_client``.
    """
    if isinstance(value, AgentKeeperRuntimeClient):
        runtime_client = value
    else:
        runtime_client = create_agentkeeper_runtime_client(**dict(value))
    return runtime_client
24
+
25
+
26
def _body_json(output: Any) -> Optional[dict[str, Any]]:
    """Parse ``output["body"]`` as a JSON object when it is already in memory.

    Only ``str`` and bytes-like bodies are parsed (bytes are decoded as UTF-8 with
    replacement characters); any other body, including ``None``, yields ``None``
    without being read. Invalid JSON or a non-object JSON value also yields ``None``.
    """
    raw = get_value(output, "body")
    if isinstance(raw, (bytes, bytearray, memoryview)):
        raw = bytes(raw).decode("utf-8", errors="replace")
    if not isinstance(raw, str):
        return None
    try:
        document = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(document, dict):
        return None
    return document
41
+
42
+
43
def _usage_fields(output: Any) -> dict[str, Any]:
    """Extract token counts from a Bedrock Runtime response.

    The usage block is taken from ``output["usage"]`` (Converse API) or, failing
    that, from the JSON body's ``usage`` or ``amazonBedrockInvocationMetrics``.
    Returns ``token_input``/``token_output`` passed through ``compact``.
    """
    body = _body_json(output) or {}
    usage = (
        get_value(output, "usage")
        or body.get("usage")
        or body.get("amazonBedrockInvocationMetrics")
        or {}
    )

    def first_count(*keys: str) -> Any:
        # Mirrors an `a or b or c` chain: stop at the first truthy count,
        # otherwise keep the last lookup's result.
        count = None
        for key in keys:
            count = number_value(get_value(usage, key))
            if count:
                break
        return count

    return compact({
        "token_input": first_count("inputTokens", "inputTokenCount", "input_tokens"),
        "token_output": first_count("outputTokens", "outputTokenCount", "output_tokens"),
    })
58
+
59
+
60
def _model_name(params: Mapping[str, Any]) -> Optional[str]:
    """Return the request's ``modelId``, falling back to ``modelIdentifier``."""
    primary = string_value(params.get("modelId"))
    return primary if primary else string_value(params.get("modelIdentifier"))
62
+
63
+
64
def _stop_reason(output: Any) -> Optional[str]:
    """Return why generation stopped, if the response reports it.

    Checks ``output["stopReason"]`` first, then ``stopReason`` and ``stop_reason``
    in the parsed JSON body.
    """
    body = _body_json(output) or {}
    direct = string_value(get_value(output, "stopReason"))
    if direct:
        return direct
    return string_value(body.get("stopReason")) or string_value(body.get("stop_reason"))
71
+
72
+
73
class _BedrockRuntimeClientProxy:
    """Forwarding proxy that reports AWS Bedrock Runtime model calls to AgentKeeper.

    Attribute access is delegated to the wrapped client. The operations in
    ``_TRACKED_OPERATIONS`` are returned wrapped so that each call emits a start
    event and a completion event through ``ak.track``; everything else is
    returned unchanged. For the ``*_stream`` operations the completion event is
    emitted when the call returns, before the caller reads the event stream.
    """

    # Client methods whose calls are reported.
    _TRACKED_OPERATIONS = frozenset(
        {"converse", "invoke_model", "converse_stream", "invoke_model_with_response_stream"}
    )

    def __init__(self, bedrock_client: Any, ak: AgentKeeperRuntimeClient, options: Optional[Mapping[str, Any]] = None) -> None:
        self._bedrock_client = bedrock_client
        self._ak = ak  # receives events via .track(...)
        self._options = dict(options or {})

    def __getattr__(self, name: str) -> Any:
        value = getattr(self._bedrock_client, name)
        if name in self._TRACKED_OPERATIONS and callable(value):
            return self._wrap_method(name, value)
        return value

    def _option(self, snake_key: str, camel_key: str) -> Any:
        """Read an option by its snake_case key, falling back to the camelCase spelling."""
        return self._options.get(snake_key) or self._options.get(camel_key)

    def _base_event(self) -> dict[str, Any]:
        """Fields shared by every event; entries of ``options["event"]`` override them."""
        return {
            "runtime_integration": "bedrock",
            "runtime_service": self._option("runtime_service", "runtimeService"),
            "runtime_environment": self._option("runtime_environment", "runtimeEnvironment"),
            **dict(self._options.get("event") or {}),
        }

    def _track_start(self, operation: str, params: Mapping[str, Any], metadata: Mapping[str, Any]) -> None:
        """Emit the "observed" event sent before the underlying call runs."""
        self._ak.track({
            **self._base_event(),
            "event_kind": "model_call",
            "capability": "model_only",
            "model_provider": "aws_bedrock",
            "model_name": _model_name(params),
            "run_id": metadata.get("run_id"),
            "trace_id": metadata.get("trace_id"),
            "span_id": metadata.get("span_id"),
            "verdict": "observed",
            "evidence_summary": f"AWS Bedrock Runtime {operation} observed",
            "metadata": metadata,
        })

    def _track_completion(
        self,
        operation: str,
        params: Mapping[str, Any],
        metadata: Mapping[str, Any],
        output: Any,
        error: Optional[BaseException] = None,
    ) -> None:
        """Emit the completion event: "passed"/low on success, "observed"/medium on error."""
        self._ak.track({
            **self._base_event(),
            "event_kind": "model_call",
            "capability": "model_only",
            "model_provider": "aws_bedrock",
            "model_name": _model_name(params),
            **_usage_fields(output),
            "run_id": metadata.get("run_id"),
            "trace_id": metadata.get("trace_id"),
            "span_id": metadata.get("span_id"),
            "verdict": "observed" if error else "passed",
            "severity": "medium" if error else "low",
            "evidence_summary": (
                f"AWS Bedrock Runtime {operation} errored: {safe_error_name(error)}"
                if error
                else f"AWS Bedrock Runtime {operation} completed"
            ),
            "metadata": compact({
                **dict(metadata),
                "operation": operation,
                "request_keys": object_keys(params),
                "stop_reason": _stop_reason(output),
                "response_has_body": get_value(output, "body") is not None,
            }),
        })

    def _wrap_method(self, operation: str, method: Any) -> Any:
        """Return a wrapper that reports a start event and one completion event per call.

        Awaitable results are awaited inside a coroutine that the wrapper returns.
        Exceptions are reported and then re-raised unchanged.
        """
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # boto3 operations take keyword arguments; a leading mapping is merged
            # underneath them for callers that pass a request dict positionally.
            params = dict(kwargs)
            if args and isinstance(args[0], Mapping):
                params = {**dict(args[0]), **params}
            metadata = compact({
                **telemetry_metadata(
                    runtime_service=self._option("runtime_service", "runtimeService"),
                    runtime_environment=self._option("runtime_environment", "runtimeEnvironment"),
                    run_id=self._option("run_id", "runId"),
                    trace_id=self._option("trace_id", "traceId"),
                    span_id=self._option("span_id", "spanId"),
                    # User and conversation ids are forwarded as well so Bedrock
                    # events can be correlated the same way as other integrations.
                    user_id=self._option("user_id", "userId"),
                    convo_id=self._option("convo_id", "convoId"),
                ),
                "operation": operation,
                "request_keys": object_keys(params),
            })
            self._track_start(operation, params, metadata)
            try:
                result = method(*args, **kwargs)
            except BaseException as error:
                self._track_completion(operation, params, metadata, {}, error)
                raise
            if inspect.isawaitable(result):
                async def wait() -> Any:
                    try:
                        resolved = await result
                    except BaseException as error:
                        self._track_completion(operation, params, metadata, {}, error)
                        raise
                    self._track_completion(operation, params, metadata, resolved)
                    return resolved

                return wait()
            self._track_completion(operation, params, metadata, result)
            return result

        return wrapped
179
+
180
+
181
def wrap_bedrock_runtime_client(
    bedrock_client: Any,
    client_or_options: AgentKeeperRuntimeClient | Mapping[str, Any],
    options: Optional[Mapping[str, Any]] = None,
    **extra_options: Any,
) -> Any:
    """Wrap a Bedrock Runtime client so its model calls are reported to AgentKeeper.

    Args:
        bedrock_client: The client object to proxy.
        client_or_options: An existing ``AgentKeeperRuntimeClient``, or a mapping of
            keyword arguments for ``create_agentkeeper_runtime_client``.
        options: Integration options, e.g. ``runtime_service`` / ``runtimeService``,
            ``run_id``, ``trace_id``, ``event``.
        **extra_options: The same options given as keyword arguments (the calling
            style of ``wrap_azure_openai_client``); they override keys in *options*.

    Returns:
        A proxy that forwards attribute access to *bedrock_client*.
    """
    merged_options = {**dict(options or {}), **extra_options}
    return _BedrockRuntimeClientProxy(bedrock_client, _client_from(client_or_options), merged_options)
187
+