tuningengines-cli 0.3.6 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +191 -5
- package/dist/cli.js +11 -1
- package/dist/cli.js.map +1 -1
- package/dist/client.d.ts +84 -0
- package/dist/client.d.ts.map +1 -1
- package/dist/client.js +138 -0
- package/dist/client.js.map +1 -1
- package/dist/commands/agents.d.ts +4 -0
- package/dist/commands/agents.d.ts.map +1 -0
- package/dist/commands/agents.js +72 -0
- package/dist/commands/agents.js.map +1 -0
- package/dist/commands/datasets.d.ts +4 -0
- package/dist/commands/datasets.d.ts.map +1 -0
- package/dist/commands/datasets.js +154 -0
- package/dist/commands/datasets.js.map +1 -0
- package/dist/commands/evaluations.d.ts +4 -0
- package/dist/commands/evaluations.d.ts.map +1 -0
- package/dist/commands/evaluations.js +168 -0
- package/dist/commands/evaluations.js.map +1 -0
- package/dist/commands/inference.d.ts +4 -0
- package/dist/commands/inference.d.ts.map +1 -0
- package/dist/commands/inference.js +92 -0
- package/dist/commands/inference.js.map +1 -0
- package/dist/commands/tenant.d.ts +4 -0
- package/dist/commands/tenant.d.ts.map +1 -0
- package/dist/commands/tenant.js +326 -0
- package/dist/commands/tenant.js.map +1 -0
- package/dist/mcp.d.ts.map +1 -1
- package/dist/mcp.js +349 -27
- package/dist/mcp.js.map +1 -1
- package/package.json +3 -2
- package/packages/tuning-agents/README.md +168 -0
- package/packages/tuning-agents/examples/langgraph_agent.py +25 -0
- package/packages/tuning-agents/examples/temporal_worker.py +27 -0
- package/packages/tuning-agents/pyproject.toml +38 -0
- package/packages/tuning-agents/tests/test_mcp.py +40 -0
- package/packages/tuning-agents/tests/test_trace.py +14 -0
- package/packages/tuning-agents/tuning_agents/__init__.py +11 -0
- package/packages/tuning-agents/tuning_agents/client.py +275 -0
- package/packages/tuning-agents/tuning_agents/langgraph.py +108 -0
- package/packages/tuning-agents/tuning_agents/mcp.py +208 -0
- package/packages/tuning-agents/tuning_agents/resources.py +26 -0
- package/packages/tuning-agents/tuning_agents/temporal.py +172 -0
- package/packages/tuning-agents/tuning_agents/trace.py +88 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .client import TuningClient
|
|
6
|
+
from .mcp import make_agent_langchain_tools, make_langchain_tools
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def create_tuning_langgraph_agent(
    client: TuningClient,
    *,
    model: str = "auto",
    prompt: str | None = None,
    server_names: set[str] | None = None,
    tool_names: set[str] | None = None,
    agent_names: list[str] | set[str] | None = None,
    agent_descriptions: dict[str, str] | None = None,
    checkpointer: Any | None = None,
    interrupt_before: list[str] | None = None,
    **agent_kwargs: Any,
) -> Any:
    """Create a LangGraph ReAct agent backed by Tuning Engines.

    This gives callers a real agent loop while Tuning Engines remains the
    governed gateway for models, MCP tool execution, RBAC, policy, audit,
    request capture, and token economics.
    """

    try:
        from langchain_openai import ChatOpenAI
        from langgraph.prebuilt import create_react_agent
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ImportError("Install tuning-agents[langgraph] to use the LangGraph adapter") from exc

    # The LLM speaks OpenAI wire format but is routed through the Tuning
    # Engines inference gateway, not OpenAI directly.
    chat_llm = ChatOpenAI(
        model=model,
        api_key=client.api_key,
        base_url=client.inference_url,
        timeout=client.timeout,
    )

    # MCP tools first, then optional delegation tools for registered agents.
    all_tools = make_langchain_tools(client, server_names=server_names, tool_names=tool_names)
    if agent_names:
        delegation_tools = make_agent_langchain_tools(
            client,
            agent_names=agent_names,
            descriptions=agent_descriptions,
        )
        all_tools.extend(delegation_tools)

    # Forward only the optional knobs that were explicitly provided, layered
    # on top of any caller-supplied passthrough kwargs.
    extra: dict[str, Any] = dict(agent_kwargs)
    for key, value in (
        ("checkpointer", checkpointer),
        ("interrupt_before", interrupt_before),
        ("prompt", prompt),
    ):
        if value is not None:
            extra[key] = value

    span_id = client.trace.start(
        "langgraph.agent.create",
        {"model": model, "tools": [getattr(tool, "name", None) for tool in all_tools]},
    )
    try:
        react_agent = create_react_agent(chat_llm, all_tools, **extra)
        client.trace.finish(span_id, {"tool_count": len(all_tools)})
        return react_agent
    except Exception as exc:
        client.trace.error(span_id, exc)
        raise
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def invoke_with_trace(
    client: TuningClient,
    agent: Any,
    messages: list[dict[str, str]],
    *,
    thread_id: str | None = None,
    **config: Any,
) -> Any:
    """Invoke *agent* once, recording start/finish/error spans on the client trace.

    When *thread_id* is given it is injected into the LangGraph runnable
    config under ``configurable.thread_id``.
    """
    span_id = client.trace.start("langgraph.agent.invoke", {"thread_id": thread_id})
    try:
        cfg: dict[str, Any] = {**config}
        if thread_id:
            configurable = cfg.setdefault("configurable", {})
            configurable["thread_id"] = thread_id
        # An empty config collapses to None so LangGraph uses its defaults.
        output = agent.invoke({"messages": messages}, cfg if cfg else None)
        client.trace.finish(span_id, {"thread_id": thread_id})
        return output
    except Exception as exc:
        client.trace.error(span_id, exc)
        raise
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def invoke_and_flush_trace(
    client: TuningClient,
    agent: Any,
    messages: list[dict[str, str]],
    *,
    thread_id: str | None = None,
    name: str | None = None,
    **config: Any,
) -> Any:
    """Invoke the agent, then flush the collected trace with a final status.

    A successful run flushes with ``status="succeeded"``; any exception
    flushes with ``status="failed"`` before re-raising.
    """
    try:
        outcome = invoke_with_trace(client, agent, messages, thread_id=thread_id, **config)
        client.flush_trace(name=name, runtime="langgraph", status="succeeded", metadata={"thread_id": thread_id})
        return outcome
    except Exception:
        client.flush_trace(name=name, runtime="langgraph", status="failed", metadata={"thread_id": thread_id})
        raise
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, create_model
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from .client import TuningClient
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def normalize_mcp_tools(payload: Any) -> list[dict[str, Any]]:
    """Normalize common MCP tool-list response shapes.

    The proxy may return `{"tools": [...]}` or group tools by server. This
    function keeps the adapter tolerant while preserving server attribution.
    """

    # Flat shapes: {"tools": [...]} or a bare list.
    if isinstance(payload, dict):
        flat = payload.get("tools")
        if isinstance(flat, list):
            return [dict(entry) for entry in flat]
    if isinstance(payload, list):
        return [dict(entry) for entry in payload]

    # Grouped shape: {"servers": [{"name": ..., "tools": [...]}, ...]}.
    normalized: list[dict[str, Any]] = []
    if isinstance(payload, dict):
        for server in payload.get("servers", []) or []:
            attribution = server.get("name") or server.get("server_name")
            for entry in server.get("tools", []) or []:
                record = dict(entry)
                # Keep an explicit server_name already present on the tool.
                record.setdefault("server_name", attribution)
                normalized.append(record)
    return normalized
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def pydantic_model_from_json_schema(name: str, schema: Mapping[str, Any] | None) -> type[BaseModel]:
    """Build a pydantic args model from a tool's JSON-schema ``inputSchema``.

    Required properties become required fields. Every other property defaults
    to ``None`` and is annotated ``T | None`` so that an explicit ``null``
    (which LLMs frequently emit for optional arguments) validates instead of
    raising, matching the omission default. If the schema declares no
    properties, a single catch-all ``arguments`` dict field is exposed.
    """

    properties = dict((schema or {}).get("properties") or {})
    required = set((schema or {}).get("required") or [])
    fields: dict[str, tuple[Any, Any]] = {}

    for field_name, spec in properties.items():
        spec = spec or {}
        annotation = _json_type_to_python(spec.get("type"))
        if field_name in required:
            default: Any = ...
        else:
            default = None
            # Optional fields must accept an explicit null, not just omission;
            # Any already admits None, so only widen concrete types.
            if annotation is not Any:
                annotation = annotation | None
        fields[field_name] = (
            annotation,
            Field(default, description=spec.get("description")),
        )

    if not fields:
        # No declared properties: accept arbitrary arguments as one object.
        fields["arguments"] = (
            dict[str, Any],
            Field(default_factory=dict, description="Tool arguments as a JSON object."),
        )

    return create_model(name, **fields)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def make_langchain_tools(
    client: TuningClient,
    *,
    server_names: set[str] | None = None,
    tool_names: set[str] | None = None,
) -> list[Any]:
    """Turn the proxy's MCP tool catalog into LangChain ``StructuredTool``s.

    Tools may be filtered by server and/or tool name; catalog entries missing
    either a server or a tool name are skipped.
    """
    try:
        from langchain_core.tools import StructuredTool
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ImportError("Install tuning-agents[langgraph] to build LangGraph tools") from exc

    catalog = normalize_mcp_tools(client.list_mcp_tools())
    structured: list[Any] = []
    for entry in catalog:
        server = entry.get("server_name") or entry.get("server") or entry.get("mcp_server")
        tool = entry.get("name") or entry.get("tool_name")
        if not (server and tool):
            continue
        if server_names and server not in server_names:
            continue
        if tool_names and tool not in tool_names:
            continue

        schema_model = pydantic_model_from_json_schema(
            f"{_safe_identifier(server)}_{_safe_identifier(tool)}_Args",
            entry.get("inputSchema") or entry.get("input_schema") or entry.get("schema"),
        )

        # Bind server/tool as defaults so each closure keeps its own pair
        # (late-binding closures would otherwise all see the last entry).
        def _invoke(_server_name: str = server, _tool_name: str = tool, **kwargs: Any) -> Any:
            # Catch-all arg models wrap everything under one "arguments" dict.
            if set(kwargs) == {"arguments"} and isinstance(kwargs["arguments"], dict):
                kwargs = kwargs["arguments"]
            return client.call_mcp_tool(
                server_name=_server_name,
                tool_name=_tool_name,
                arguments=kwargs,
            )

        structured.append(
            StructuredTool.from_function(
                func=_invoke,
                name=f"{_safe_identifier(server)}__{_safe_identifier(tool)}",
                description=entry.get("description") or f"Call {tool} on MCP server {server}.",
                args_schema=schema_model,
            )
        )

    return structured
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def make_agent_langchain_tools(
    client: "TuningClient",
    *,
    agent_names: list[str] | set[str],
    descriptions: Mapping[str, str] | None = None,
) -> list[Any]:
    """Expose registered Tuning Engines agents as LangChain delegation tools."""
    try:
        from langchain_core.tools import StructuredTool
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ImportError("Install tuning-agents[langgraph] to build LangGraph tools") from exc

    description_map = descriptions or {}

    class AgentArgs(BaseModel):
        message: str = Field(..., description="Task or message to send to the registered agent.")
        context: dict[str, Any] = Field(default_factory=dict, description="Optional structured context.")

    delegation_tools: list[Any] = []
    for raw_name in agent_names:
        # Bind the loop variable as a default to avoid late-binding closures.
        def _delegate(message: str, context: dict[str, Any] | None = None, _agent_name: str = raw_name) -> Any:
            return client.call_agent(agent_name=_agent_name, message=message, context=context or {})

        delegation_tools.append(
            StructuredTool.from_function(
                func=_delegate,
                name=f"agent__{_safe_identifier(raw_name)}",
                description=description_map.get(raw_name) or f"Delegate a full task to agent {raw_name}.",
                args_schema=AgentArgs,
            )
        )
    return delegation_tools
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def skill_tool_spec(
|
|
141
|
+
name: str,
|
|
142
|
+
*,
|
|
143
|
+
description: str | None = None,
|
|
144
|
+
parameters: Mapping[str, Any] | None = None,
|
|
145
|
+
) -> dict[str, Any]:
|
|
146
|
+
"""Build an OpenAI tool spec for a registered skill.
|
|
147
|
+
|
|
148
|
+
Skills are enforced by the proxy orchestrator when their function name
|
|
149
|
+
matches the registered tenant/platform skill name. Unlike MCP and A2A
|
|
150
|
+
agents, this SDK does not invent a local execution endpoint for skills.
|
|
151
|
+
"""
|
|
152
|
+
|
|
153
|
+
return {
|
|
154
|
+
"type": "function",
|
|
155
|
+
"function": {
|
|
156
|
+
"name": name,
|
|
157
|
+
"description": description or f"Invoke the governed skill {name}.",
|
|
158
|
+
"parameters": parameters
|
|
159
|
+
or {
|
|
160
|
+
"type": "object",
|
|
161
|
+
"properties": {
|
|
162
|
+
"input": {"type": "string", "description": "Input for the skill."},
|
|
163
|
+
"context": {"type": "object", "description": "Optional structured context."},
|
|
164
|
+
},
|
|
165
|
+
"required": ["input"],
|
|
166
|
+
},
|
|
167
|
+
},
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def agent_tool_spec(
|
|
172
|
+
name: str,
|
|
173
|
+
*,
|
|
174
|
+
description: str | None = None,
|
|
175
|
+
) -> dict[str, Any]:
|
|
176
|
+
return {
|
|
177
|
+
"type": "function",
|
|
178
|
+
"function": {
|
|
179
|
+
"name": name,
|
|
180
|
+
"description": description or f"Delegate a full task to registered agent {name}.",
|
|
181
|
+
"parameters": {
|
|
182
|
+
"type": "object",
|
|
183
|
+
"properties": {
|
|
184
|
+
"message": {"type": "string", "description": "Message or task for the agent."},
|
|
185
|
+
"context": {"type": "object", "description": "Optional structured context."},
|
|
186
|
+
},
|
|
187
|
+
"required": ["message"],
|
|
188
|
+
},
|
|
189
|
+
},
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _json_type_to_python(json_type: Any) -> Any:
|
|
194
|
+
if isinstance(json_type, list):
|
|
195
|
+
json_type = next((item for item in json_type if item != "null"), None)
|
|
196
|
+
return {
|
|
197
|
+
"string": str,
|
|
198
|
+
"integer": int,
|
|
199
|
+
"number": float,
|
|
200
|
+
"boolean": bool,
|
|
201
|
+
"array": list[Any],
|
|
202
|
+
"object": dict[str, Any],
|
|
203
|
+
}.get(json_type, Any)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _safe_identifier(value: str) -> str:
|
|
207
|
+
cleaned = "".join(ch if ch.isalnum() else "_" for ch in value)
|
|
208
|
+
return cleaned.strip("_") or "tool"
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .mcp import agent_tool_spec, skill_tool_spec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(slots=True)
class ResourceManifest:
    """Runtime-facing manifest for governed Tuning Engines resources.

    Use this when the application knows which agents/skills should be exposed
    to a run. The proxy still enforces RBAC/AGT using the tenant registry.
    """

    # Model routed through the gateway; "auto" lets the proxy pick.
    model: str = "auto"
    # Registered agent name -> human description (used in the generated spec).
    agents: dict[str, str] = field(default_factory=dict)
    # Registered skill name -> human description.
    skills: dict[str, str] = field(default_factory=dict)
    # Pre-built OpenAI tool specs appended verbatim, ahead of generated ones.
    extra_tools: list[dict[str, Any]] = field(default_factory=list)

    def openai_tools(self) -> list[dict[str, Any]]:
        """Return extra tools plus generated agent and skill tool specs."""
        specs = list(self.extra_tools)
        for agent_name, agent_description in self.agents.items():
            specs.append(agent_tool_spec(agent_name, description=agent_description))
        for skill_name, skill_description in self.skills.items():
            specs.append(skill_tool_spec(skill_name, description=skill_description))
        return specs
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from .client import TuningClient
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AgentRunInput:
    """Serializable input for a Temporal agent-workflow run."""

    # Tuning Engines API key; forwarded to every activity payload.
    api_key: str
    # OpenAI-compatible inference gateway endpoint.
    inference_url: str = "https://api.tuningengines.com/v1"
    # Control-plane API endpoint.
    api_url: str = "https://app.tuningengines.com"
    # Model routed through the gateway; "auto" lets the proxy pick.
    model: str = "auto"
    # OpenAI-style chat messages that seed the run.
    messages: list[dict[str, Any]] = field(default_factory=list)
    # Optional OpenAI tool specs offered to the model.
    tools: list[dict[str, Any]] | None = None
    # Upper bound on LLM round-trips before the workflow gives up.
    max_steps: int = 8
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class AgentRunResult:
    """Final assistant output plus the trace events collected across activities."""

    # Final assistant message (or a max-steps sentinel message).
    output: Any
    # {"events": [...]} payload of SDK-side trace events.
    trace: dict[str, Any]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def chat_completion_activity(payload: dict[str, Any]) -> dict[str, Any]:
    """Temporal activity: one chat completion through the Tuning gateway.

    NOTE(review): ``client.chat`` is called without ``await`` — presumably
    synchronous; confirm it does not block the event loop for long completions.
    """
    client = _client_from_payload(payload)
    tools = payload.get("tools")
    # tool_choice is only meaningful when tools are offered.
    response = client.chat(
        model=payload["model"],
        messages=payload["messages"],
        tools=tools,
        tool_choice=payload.get("tool_choice", "auto") if tools else None,
    )
    if hasattr(response, "model_dump"):
        return response.model_dump(mode="json")
    return response
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def mcp_tool_activity(payload: dict[str, Any]) -> dict[str, Any]:
    """Temporal activity: execute one MCP tool call; return result plus trace."""
    client = _client_from_payload(payload)
    outcome = await client.acall_mcp_tool(
        server_name=payload["server_name"],
        tool_name=payload["tool_name"],
        arguments=payload.get("arguments") or {},
    )
    return {"result": outcome, "trace": client.trace.as_dict()}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
async def agent_message_activity(payload: dict[str, Any]) -> dict[str, Any]:
    """Temporal activity: send one message to a registered agent; return result plus trace."""
    client = _client_from_payload(payload)
    outcome = await client.acall_agent(
        agent_name=payload["agent_name"],
        message=payload["message"],
        context=payload.get("context") or {},
    )
    return {"result": outcome, "trace": client.trace.as_dict()}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def define_temporal_workflow() -> type[Any]:
    """Return a Temporal workflow class without importing Temporal at module load.

    Temporal workflow modules are replayed under deterministic constraints. This
    factory keeps optional dependencies isolated for users who only need
    LangGraph.
    """

    try:
        from temporalio import workflow
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ImportError("Install tuning-agents[temporal] to use the Temporal adapter") from exc

    @workflow.defn
    class TuningAgentWorkflow:
        @workflow.run
        async def run(self, request: AgentRunInput) -> AgentRunResult:
            # Local copy: the loop appends assistant/tool messages per step.
            messages = list(request.messages)
            trace_events: list[dict[str, Any]] = []

            for step in range(request.max_steps):
                # One LLM round-trip per step, run as an activity so the
                # non-deterministic network call stays outside workflow code.
                llm = await workflow.execute_activity(
                    chat_completion_activity,
                    {
                        "api_key": request.api_key,
                        "api_url": request.api_url,
                        "inference_url": request.inference_url,
                        "model": request.model,
                        "messages": messages,
                        "tools": request.tools,
                    },
                    start_to_close_timeout=timedelta(minutes=5),
                )
                trace_events.extend(_events_from_activity(llm))
                choice = (llm.get("choices") or [{}])[0]
                message = choice.get("message") or {}
                messages.append(message)

                tool_calls = message.get("tool_calls") or []
                if not tool_calls:
                    # No tool use requested: the assistant message is the answer.
                    return AgentRunResult(output=message, trace={"events": trace_events})

                for tool_call in tool_calls:
                    function = tool_call.get("function") or {}
                    tool_call_name = function.get("name") or ""
                    arguments = function.get("arguments") or {}
                    if isinstance(arguments, str):
                        # OpenAI-style responses serialize arguments as JSON text.
                        arguments = json.loads(arguments or "{}")
                    if tool_call_name.startswith("agent__"):
                        # "agent__<name>" tools delegate a sub-task to a registered agent.
                        agent_name = tool_call_name.removeprefix("agent__")
                        tool_result = await workflow.execute_activity(
                            agent_message_activity,
                            {
                                "api_key": request.api_key,
                                "api_url": request.api_url,
                                "inference_url": request.inference_url,
                                "agent_name": agent_name,
                                "message": arguments.get("message") or arguments.get("input") or "",
                                "context": arguments.get("context") or {},
                            },
                            start_to_close_timeout=timedelta(minutes=2),
                        )
                    else:
                        # Everything else is an MCP tool named "<server>__<tool>"
                        # or "<server>.<tool>".
                        server_name, tool_name = _split_tool_name(tool_call_name)
                        tool_result = await workflow.execute_activity(
                            mcp_tool_activity,
                            {
                                "api_key": request.api_key,
                                "api_url": request.api_url,
                                "inference_url": request.inference_url,
                                "server_name": server_name,
                                "tool_name": tool_name,
                                "arguments": arguments,
                            },
                            start_to_close_timeout=timedelta(minutes=5),
                        )
                    trace_events.extend(_events_from_activity(tool_result))
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.get("id"),
                            "content": str(tool_result.get("result")),
                        }
                    )

            # Budget exhausted without a final (tool-free) assistant message.
            return AgentRunResult(
                output={"role": "assistant", "content": "Max steps reached."},
                trace={"events": trace_events},
            )

    return TuningAgentWorkflow
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _client_from_payload(payload: dict[str, Any]) -> TuningClient:
    """Rebuild a TuningClient inside an activity from a serialized payload."""
    api_url = payload.get("api_url", "https://app.tuningengines.com")
    inference_url = payload.get("inference_url", "https://api.tuningengines.com/v1")
    return TuningClient(api_key=payload["api_key"], api_url=api_url, inference_url=inference_url)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _split_tool_name(name: str) -> tuple[str, str]:
|
|
161
|
+
if "__" in name:
|
|
162
|
+
return tuple(name.split("__", 1)) # type: ignore[return-value]
|
|
163
|
+
if "." in name:
|
|
164
|
+
return tuple(name.split(".", 1)) # type: ignore[return-value]
|
|
165
|
+
raise ValueError(f"Tool name must include server and tool, got {name!r}")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _events_from_activity(payload: Any) -> list[Any]:
|
|
169
|
+
if isinstance(payload, dict):
|
|
170
|
+
trace = payload.get("trace") or {}
|
|
171
|
+
return trace.get("events") or []
|
|
172
|
+
return []
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(slots=True)
|
|
10
|
+
class TraceEvent:
|
|
11
|
+
id: str
|
|
12
|
+
type: str
|
|
13
|
+
status: str
|
|
14
|
+
at: float
|
|
15
|
+
parent_id: str | None = None
|
|
16
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(slots=True)
|
|
20
|
+
class TraceRecorder:
|
|
21
|
+
"""In-process trace collector for agent runs.
|
|
22
|
+
|
|
23
|
+
This captures the full SDK-side causal chain. Rails already records
|
|
24
|
+
gateway usage, request capture, audit, and token economics. Persist these
|
|
25
|
+
events in your app or forward them once a public trace-ingest endpoint is
|
|
26
|
+
added.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
run_id: str = field(default_factory=lambda: f"run_{uuid.uuid4().hex}")
|
|
30
|
+
events: list[TraceEvent] = field(default_factory=list)
|
|
31
|
+
|
|
32
|
+
def start(self, event_type: str, metadata: dict[str, Any] | None = None, parent_id: str | None = None) -> str:
|
|
33
|
+
event_id = f"evt_{uuid.uuid4().hex}"
|
|
34
|
+
self.events.append(
|
|
35
|
+
TraceEvent(
|
|
36
|
+
id=event_id,
|
|
37
|
+
type=event_type,
|
|
38
|
+
status="started",
|
|
39
|
+
at=time.time(),
|
|
40
|
+
parent_id=parent_id,
|
|
41
|
+
metadata={"run_id": self.run_id, **(metadata or {})},
|
|
42
|
+
)
|
|
43
|
+
)
|
|
44
|
+
return event_id
|
|
45
|
+
|
|
46
|
+
def finish(self, event_id: str, metadata: dict[str, Any] | None = None) -> None:
|
|
47
|
+
self.events.append(
|
|
48
|
+
TraceEvent(
|
|
49
|
+
id=f"evt_{uuid.uuid4().hex}",
|
|
50
|
+
type="span.finish",
|
|
51
|
+
status="succeeded",
|
|
52
|
+
at=time.time(),
|
|
53
|
+
parent_id=event_id,
|
|
54
|
+
metadata={"run_id": self.run_id, **(metadata or {})},
|
|
55
|
+
)
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
def error(self, event_id: str, exc: BaseException) -> None:
|
|
59
|
+
self.events.append(
|
|
60
|
+
TraceEvent(
|
|
61
|
+
id=f"evt_{uuid.uuid4().hex}",
|
|
62
|
+
type="span.error",
|
|
63
|
+
status="failed",
|
|
64
|
+
at=time.time(),
|
|
65
|
+
parent_id=event_id,
|
|
66
|
+
metadata={
|
|
67
|
+
"run_id": self.run_id,
|
|
68
|
+
"error_type": exc.__class__.__name__,
|
|
69
|
+
"error": str(exc),
|
|
70
|
+
},
|
|
71
|
+
)
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
def as_dict(self) -> dict[str, Any]:
|
|
75
|
+
return {
|
|
76
|
+
"run_id": self.run_id,
|
|
77
|
+
"events": [
|
|
78
|
+
{
|
|
79
|
+
"id": event.id,
|
|
80
|
+
"type": event.type,
|
|
81
|
+
"status": event.status,
|
|
82
|
+
"at": event.at,
|
|
83
|
+
"parent_id": event.parent_id,
|
|
84
|
+
"metadata": event.metadata,
|
|
85
|
+
}
|
|
86
|
+
for event in self.events
|
|
87
|
+
],
|
|
88
|
+
}
|