tuningengines-cli 0.3.6 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +191 -5
- package/dist/cli.js +11 -1
- package/dist/cli.js.map +1 -1
- package/dist/client.d.ts +84 -0
- package/dist/client.d.ts.map +1 -1
- package/dist/client.js +138 -0
- package/dist/client.js.map +1 -1
- package/dist/commands/agents.d.ts +4 -0
- package/dist/commands/agents.d.ts.map +1 -0
- package/dist/commands/agents.js +72 -0
- package/dist/commands/agents.js.map +1 -0
- package/dist/commands/datasets.d.ts +4 -0
- package/dist/commands/datasets.d.ts.map +1 -0
- package/dist/commands/datasets.js +154 -0
- package/dist/commands/datasets.js.map +1 -0
- package/dist/commands/evaluations.d.ts +4 -0
- package/dist/commands/evaluations.d.ts.map +1 -0
- package/dist/commands/evaluations.js +168 -0
- package/dist/commands/evaluations.js.map +1 -0
- package/dist/commands/inference.d.ts +4 -0
- package/dist/commands/inference.d.ts.map +1 -0
- package/dist/commands/inference.js +92 -0
- package/dist/commands/inference.js.map +1 -0
- package/dist/commands/tenant.d.ts +4 -0
- package/dist/commands/tenant.d.ts.map +1 -0
- package/dist/commands/tenant.js +326 -0
- package/dist/commands/tenant.js.map +1 -0
- package/dist/mcp.d.ts.map +1 -1
- package/dist/mcp.js +349 -27
- package/dist/mcp.js.map +1 -1
- package/package.json +3 -2
- package/packages/tuning-agents/README.md +168 -0
- package/packages/tuning-agents/examples/langgraph_agent.py +25 -0
- package/packages/tuning-agents/examples/temporal_worker.py +27 -0
- package/packages/tuning-agents/pyproject.toml +38 -0
- package/packages/tuning-agents/tests/test_mcp.py +40 -0
- package/packages/tuning-agents/tests/test_trace.py +14 -0
- package/packages/tuning-agents/tuning_agents/__init__.py +11 -0
- package/packages/tuning-agents/tuning_agents/client.py +275 -0
- package/packages/tuning-agents/tuning_agents/langgraph.py +108 -0
- package/packages/tuning-agents/tuning_agents/mcp.py +208 -0
- package/packages/tuning-agents/tuning_agents/resources.py +26 -0
- package/packages/tuning-agents/tuning_agents/temporal.py +172 -0
- package/packages/tuning-agents/tuning_agents/trace.py +88 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# tuning-agents
|
|
2
|
+
|
|
3
|
+
Governed agent runtime adapters for Tuning Engines.
|
|
4
|
+
|
|
5
|
+
This package keeps orchestration outside Rails while making agent runtimes use
|
|
6
|
+
Tuning Engines for the things it already does well:
|
|
7
|
+
|
|
8
|
+
- OpenAI-compatible model access through the inference gateway
|
|
9
|
+
- MCP tool discovery and execution through `/v1/mcp/tools*`
|
|
10
|
+
- A2A tenant-agent dispatch through `/v1/agents/{name}/message`
|
|
11
|
+
- Agent/skill OpenAI tool specs that line up with proxy RBAC and AGT policy
|
|
12
|
+
- Registry/RBAC/governance enforcement at the gateway
|
|
13
|
+
- Usage, request capture, auditability, and token economics
|
|
14
|
+
- Client-side causal traces for LLM calls, MCP calls, LangGraph runs, and
|
|
15
|
+
Temporal activities
|
|
16
|
+
|
|
17
|
+
## Install
|
|
18
|
+
|
|
19
|
+
```bash
pip install "tuning-agents[langgraph]"
pip install "tuning-agents[temporal]"
```
|
|
23
|
+
|
|
24
|
+
From this repository:
|
|
25
|
+
|
|
26
|
+
```bash
pip install -e "packages/tuning-agents[langgraph,temporal]"
```
|
|
29
|
+
|
|
30
|
+
## LangGraph
|
|
31
|
+
|
|
32
|
+
LangGraph provides the actual agent loop, checkpoints, memory, interrupts, and
|
|
33
|
+
human-in-the-loop workflow. Tuning Engines remains the governed model/tool
|
|
34
|
+
gateway.
|
|
35
|
+
|
|
36
|
+
```python
|
|
37
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
38
|
+
|
|
39
|
+
from tuning_agents import TuningClient
|
|
40
|
+
from tuning_agents.langgraph import create_tuning_langgraph_agent, invoke_with_trace
|
|
41
|
+
|
|
42
|
+
client = TuningClient(api_key="sk-te-...", inference_url="https://api.tuningengines.com/v1")
|
|
43
|
+
|
|
44
|
+
agent = create_tuning_langgraph_agent(
|
|
45
|
+
client,
|
|
46
|
+
model="llama-3.3-70b-fp8",
|
|
47
|
+
agent_names=["billing-escalation"],
|
|
48
|
+
checkpointer=InMemorySaver(),
|
|
49
|
+
interrupt_before=["tools"], # optional approval gate before tool execution
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
result = invoke_with_trace(
|
|
53
|
+
client,
|
|
54
|
+
agent,
|
|
55
|
+
[{"role": "user", "content": "Use the registry tools to summarize my latest jobs."}],
|
|
56
|
+
thread_id="customer-123",
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
print(result)
|
|
60
|
+
print(client.trace.as_dict())
|
|
61
|
+
|
|
62
|
+
# Store the runtime trace in Tuning Engines.
|
|
63
|
+
client.flush_trace(name="ticket-triage", runtime="langgraph", status="succeeded")
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
The LangGraph adapter exposes two executable resource classes:
|
|
67
|
+
|
|
68
|
+
- MCP tools discovered from the Tuning Engines proxy
|
|
69
|
+
- Registered tenant agents passed via `agent_names`, executed through
|
|
70
|
+
`/v1/agents/{name}/message`
|
|
71
|
+
|
|
72
|
+
Skills are different: they are governed prompt/workflow bundles represented as
|
|
73
|
+
OpenAI tool specs. Use `ResourceManifest.openai_tools()` when you want the proxy
|
|
74
|
+
to enforce skill access on a direct chat-completions call.
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
from tuning_agents.resources import ResourceManifest
|
|
78
|
+
|
|
79
|
+
manifest = ResourceManifest(
|
|
80
|
+
model="llama-3.3-70b-fp8",
|
|
81
|
+
agents={"billing-escalation": "Escalate complex billing issues."},
|
|
82
|
+
skills={"analytics": "Run the tenant analytics skill."},
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
resp = client.chat(
|
|
86
|
+
model=manifest.model,
|
|
87
|
+
messages=[{"role": "user", "content": "Analyze this ticket and escalate if needed."}],
|
|
88
|
+
tools=manifest.openai_tools(),
|
|
89
|
+
)
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
## Temporal
|
|
93
|
+
|
|
94
|
+
Temporal provides durable execution, retries, resume-after-crash, schedules, and
|
|
95
|
+
workflow history. The provided workflow is intentionally small: each LLM turn and
|
|
96
|
+
MCP tool call runs as an activity, so Temporal owns durability and Tuning Engines
|
|
97
|
+
owns model/tool governance.
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
from temporalio.client import Client
|
|
101
|
+
from temporalio.worker import Worker
|
|
102
|
+
|
|
103
|
+
from tuning_agents.temporal import (
|
|
104
|
+
AgentRunInput,
|
|
105
|
+
agent_message_activity,
|
|
106
|
+
chat_completion_activity,
|
|
107
|
+
define_temporal_workflow,
|
|
108
|
+
mcp_tool_activity,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
TuningAgentWorkflow = define_temporal_workflow()
|
|
112
|
+
|
|
113
|
+
async def main():
|
|
114
|
+
temporal = await Client.connect("localhost:7233")
|
|
115
|
+
worker = Worker(
|
|
116
|
+
temporal,
|
|
117
|
+
task_queue="tuning-agents",
|
|
118
|
+
workflows=[TuningAgentWorkflow],
|
|
119
|
+
activities=[chat_completion_activity, mcp_tool_activity, agent_message_activity],
|
|
120
|
+
)
|
|
121
|
+
await worker.run()
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
Start a run:
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
handle = await temporal.start_workflow(
|
|
128
|
+
TuningAgentWorkflow.run,
|
|
129
|
+
AgentRunInput(
|
|
130
|
+
api_key="sk-te-...",
|
|
131
|
+
model="llama-3.3-70b-fp8",
|
|
132
|
+
messages=[{"role": "user", "content": "Check available tools and answer."}],
|
|
133
|
+
),
|
|
134
|
+
id="agent-run-001",
|
|
135
|
+
task_queue="tuning-agents",
|
|
136
|
+
)
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
## Trace Semantics
|
|
140
|
+
|
|
141
|
+
This SDK captures the full client/runtime-side causal trace:
|
|
142
|
+
|
|
143
|
+
- LangGraph agent creation/invocation
|
|
144
|
+
- LLM calls
|
|
145
|
+
- MCP tool discovery and execution
|
|
146
|
+
- A2A agent dispatches
|
|
147
|
+
- Temporal workflow activities
|
|
148
|
+
- Errors and latency metadata
|
|
149
|
+
|
|
150
|
+
Rails/proxy already capture the gateway side: inference usage, request capture,
|
|
151
|
+
audit logs, policy decisions, token counts, and billing attribution. The SDK
|
|
152
|
+
captures the runtime side and can persist it with:
|
|
153
|
+
|
|
154
|
+
```python
|
|
155
|
+
client.flush_trace(name="support-agent", runtime="langgraph", status="succeeded")
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
That sends events to `POST /api/v1/traces` using the same `TE_API_KEY` auth as
|
|
159
|
+
the CLI/MCP server.
|
|
160
|
+
|
|
161
|
+
## Why this exists
|
|
162
|
+
|
|
163
|
+
The Rails app stays the control plane. This package gives customers a portable
|
|
164
|
+
runtime layer:
|
|
165
|
+
|
|
166
|
+
- LangGraph for agent loops, state, memory, interrupts, and checkpoints
|
|
167
|
+
- Temporal for crash-proof durable execution
|
|
168
|
+
- Tuning Engines for governance, registries, agents, skills, MCP, routing, usage, and economics
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Example: run a LangGraph agent through the Tuning Engines governed gateway."""

from langgraph.checkpoint.memory import InMemorySaver

from tuning_agents import TuningClient
from tuning_agents.langgraph import create_tuning_langgraph_agent, invoke_with_trace


# No explicit api_key: the client falls back to the TE_API_KEY environment variable.
client = TuningClient()

agent = create_tuning_langgraph_agent(
    client,
    model="llama-3.3-70b-fp8",
    # Tenant agents dispatched through /v1/agents/{name}/message.
    agent_names=["billing-escalation"],
    checkpointer=InMemorySaver(),
    # Pause for approval before any tool executes (optional approval gate).
    interrupt_before=["tools"],
)

result = invoke_with_trace(
    client,
    agent,
    [{"role": "user", "content": "List the available governed tools and summarize what you can do."}],
    # Checkpointer thread id; reuse it to resume the same conversation state.
    thread_id="demo-thread",
)

print(result)
# Client-side causal trace of the run (LLM calls, tool calls, errors, latency).
print(client.trace.as_dict())
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Example: Temporal worker hosting the Tuning Engines agent workflow and activities."""

import asyncio

from temporalio.client import Client
from temporalio.worker import Worker

from tuning_agents.temporal import (
    agent_message_activity,
    chat_completion_activity,
    define_temporal_workflow,
    mcp_tool_activity,
)


async def main() -> None:
    """Connect to Temporal and run a worker until interrupted."""
    # Default address of a local Temporal dev server.
    temporal = await Client.connect("localhost:7233")
    # NOTE(review): workflow class appears to be constructed at runtime by the
    # adapter — confirm against tuning_agents.temporal.define_temporal_workflow.
    workflow = define_temporal_workflow()
    worker = Worker(
        temporal,
        task_queue="tuning-agents",
        workflows=[workflow],
        # Each LLM turn, MCP tool call, and agent dispatch runs as its own activity.
        activities=[chat_completion_activity, mcp_tool_activity, agent_message_activity],
    )
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling>=1.24"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tuning-agents"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Governed agent runtime adapters for Tuning Engines, LangGraph, Temporal, and MCP."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
license = "Apache-2.0"
|
|
12
|
+
authors = [{ name = "Tuning Engines" }]
|
|
13
|
+
dependencies = [
|
|
14
|
+
"httpx>=0.27",
|
|
15
|
+
"openai>=1.40",
|
|
16
|
+
"pydantic>=2.7",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
[project.optional-dependencies]
|
|
20
|
+
langgraph = [
|
|
21
|
+
"langgraph>=0.2",
|
|
22
|
+
"langchain-openai>=0.1.20",
|
|
23
|
+
"langchain-core>=0.2.30",
|
|
24
|
+
]
|
|
25
|
+
temporal = [
|
|
26
|
+
"temporalio>=1.7",
|
|
27
|
+
]
|
|
28
|
+
dev = [
|
|
29
|
+
"pytest>=8.0",
|
|
30
|
+
"pytest-asyncio>=0.23",
|
|
31
|
+
"ruff>=0.6",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.hatch.build.targets.wheel]
|
|
35
|
+
packages = ["tuning_agents"]
|
|
36
|
+
|
|
37
|
+
[tool.ruff]
|
|
38
|
+
line-length = 100
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from tuning_agents.mcp import agent_tool_spec, normalize_mcp_tools, pydantic_model_from_json_schema, skill_tool_spec
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_normalize_flat_tools():
    """A flat {"tools": [...]} payload passes through unchanged."""
    payload = {"tools": [{"name": "x", "server_name": "s"}]}

    assert normalize_mcp_tools(payload) == [{"name": "x", "server_name": "s"}]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_normalize_grouped_tools():
    """Server-grouped payloads are flattened with server_name stamped onto each tool."""
    payload = {"servers": [{"name": "s", "tools": [{"name": "x"}]}]}

    assert normalize_mcp_tools(payload) == [{"name": "x", "server_name": "s"}]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_pydantic_model_from_json_schema():
    """A JSON-schema object yields a pydantic model with matching fields."""
    schema = {
        "type": "object",
        "required": ["query"],
        "properties": {
            "query": {"type": "string"},
            "limit": {"type": "integer"},
        },
    }

    model_cls = pydantic_model_from_json_schema("Args", schema)
    instance = model_cls(query="hello", limit=3)

    assert instance.query == "hello"
    assert instance.limit == 3
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_agent_and_skill_tool_specs():
    """Agent and skill specs expose the given name as the OpenAI function name."""
    specs = {
        "agent": agent_tool_spec("billing-escalation"),
        "skill": skill_tool_spec("analytics"),
    }

    assert specs["agent"]["function"]["name"] == "billing-escalation"
    assert specs["skill"]["function"]["name"] == "analytics"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from tuning_agents.trace import TraceRecorder
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_trace_recorder_lifecycle():
    """start() then finish() produce two events linked by parent_id under the run id."""
    recorder = TraceRecorder(run_id="run_test")
    span_id = recorder.start("unit", {"ok": True})
    recorder.finish(span_id, {"duration_ms": 1})

    snapshot = recorder.as_dict()

    assert snapshot["run_id"] == "run_test"
    assert len(snapshot["events"]) == 2
    assert snapshot["events"][0]["metadata"]["ok"] is True
    assert snapshot["events"][1]["parent_id"] == span_id
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .trace import TraceEvent, TraceRecorder
|
|
2
|
+
|
|
3
|
+
__all__ = ["TuningClient", "TraceEvent", "TraceRecorder"]


def __getattr__(name: str):
    """Lazily resolve ``TuningClient`` on first access (PEP 562 module getattr).

    ``tuning_agents.client`` imports httpx and the OpenAI SDK; deferring that
    import keeps ``import tuning_agents`` cheap for trace-only consumers.

    Raises:
        AttributeError: For any attribute other than ``TuningClient``, with
            the conventional "module ... has no attribute ..." message.
    """
    if name == "TuningClient":
        from .client import TuningClient

        return TuningClient
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
from __future__ import annotations

import os
import time
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import urlencode

import httpx
from openai import OpenAI

from .trace import TraceRecorder
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TuningError(RuntimeError):
    """Raised when the Tuning Engines API answers with an HTTP error status (>= 400)."""

    pass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(slots=True)
class TuningClient:
    """Small client for Tuning Engines control-plane and inference APIs.

    The API key can be a long-lived inference key (`sk-te-...`) when talking
    directly to the proxy, or a platform API token when using `/api/v1/*`.

    Every HTTP request and LLM call is recorded on ``trace`` and can be
    persisted with :meth:`flush_trace`.
    """

    # Falls back to the TE_API_KEY environment variable in __post_init__.
    api_key: str | None = None
    # Control-plane base URL (serves /api/v1/*).
    api_url: str = "https://app.tuningengines.com"
    # OpenAI-compatible inference gateway base URL (serves /v1/*).
    inference_url: str = "https://api.tuningengines.com/v1"
    timeout: float = 60.0
    user_agent: str = "tuning-agents/0.1.0"
    trace: TraceRecorder = field(default_factory=TraceRecorder)

    def __post_init__(self) -> None:
        """Resolve the API key and normalize base URLs.

        Raises:
            ValueError: When neither ``api_key`` nor ``TE_API_KEY`` is set.
        """
        self.api_key = self.api_key or os.getenv("TE_API_KEY")
        if not self.api_key:
            raise ValueError("api_key is required or TE_API_KEY must be set")
        # Strip trailing slashes so URL joins never produce double slashes.
        self.api_url = self.api_url.rstrip("/")
        self.inference_url = self.inference_url.rstrip("/")

    @property
    def openai(self) -> OpenAI:
        """OpenAI SDK client aimed at the inference gateway.

        Built fresh on each access so later changes to ``api_key``,
        ``inference_url``, or ``timeout`` take effect immediately.
        """
        return OpenAI(api_key=self.api_key, base_url=self.inference_url, timeout=self.timeout)

    def request(
        self,
        method: str,
        path: str,
        *,
        json: Mapping[str, Any] | None = None,
        base_url: str | None = None,
        trace_type: str = "http",
    ) -> Any:
        """Perform a traced synchronous HTTP request.

        Args:
            method: HTTP verb ("GET", "POST", ...).
            path: Request path joined onto ``base_url`` (default ``api_url``).
            json: Optional JSON body.
            base_url: Override base URL (e.g. the proxy instead of the control plane).
            trace_type: Span type recorded on ``trace``.

        Returns:
            The decoded JSON payload (raw text when the body is not JSON).

        Raises:
            TuningError: For HTTP error statuses (>= 400).
        """
        url = self._join(base_url or self.api_url, path)
        span_id = self.trace.start(trace_type, {"method": method, "url": url})
        started = time.perf_counter()
        try:
            with httpx.Client(timeout=self.timeout, headers=self._headers()) as client:
                response = client.request(method, url, json=json)
            payload = self._parse_response(response)
            self.trace.finish(
                span_id,
                {
                    "status_code": response.status_code,
                    "duration_ms": round((time.perf_counter() - started) * 1000, 2),
                },
            )
            return payload
        except Exception as exc:
            # Record the failure on the span, then surface the original error.
            self.trace.error(span_id, exc)
            raise

    async def arequest(
        self,
        method: str,
        path: str,
        *,
        json: Mapping[str, Any] | None = None,
        base_url: str | None = None,
        trace_type: str = "http",
    ) -> Any:
        """Async counterpart of :meth:`request`; same arguments, tracing, and errors."""
        url = self._join(base_url or self.api_url, path)
        span_id = self.trace.start(trace_type, {"method": method, "url": url})
        started = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=self.timeout, headers=self._headers()) as client:
                response = await client.request(method, url, json=json)
            payload = self._parse_response(response)
            self.trace.finish(
                span_id,
                {
                    "status_code": response.status_code,
                    "duration_ms": round((time.perf_counter() - started) * 1000, 2),
                },
            )
            return payload
        except Exception as exc:
            self.trace.error(span_id, exc)
            raise

    def chat(self, *, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> Any:
        """Run a traced chat completion through the governed inference gateway.

        ``None``-valued kwargs are dropped so callers can pass optional
        parameters (tools, temperature, ...) unconditionally.
        """
        span_id = self.trace.start("llm", {"model": model})
        try:
            clean_kwargs = {key: value for key, value in kwargs.items() if value is not None}
            response = self.openai.chat.completions.create(model=model, messages=messages, **clean_kwargs)
            usage = getattr(response, "usage", None)
            self.trace.finish(
                span_id,
                {
                    "model": getattr(response, "model", model),
                    "usage": usage.model_dump() if hasattr(usage, "model_dump") else usage,
                },
            )
            return response
        except Exception as exc:
            self.trace.error(span_id, exc)
            raise

    def list_models(self) -> Any:
        """List models available through the control plane."""
        return self.request("GET", "/api/v1/inference/models", trace_type="control")

    def list_training_agents(
        self,
        *,
        category: str | None = None,
        include_disabled: bool = False,
    ) -> Any:
        """List registered agents, optionally filtered by category / disabled state."""
        query = self._query(
            {
                "category": category,
                # The API expects the lowercase literal "true".
                "include_disabled": "true" if include_disabled else None,
            }
        )
        return self.request("GET", f"/api/v1/agents{query}", trace_type="control")

    def list_usage(self, *, model: str | None = None, limit: int | None = None) -> Any:
        """List inference usage records, optionally filtered by model and capped at limit."""
        query = self._query({"model": model, "limit": limit})
        return self.request("GET", f"/api/v1/inference/usage{query}", trace_type="control")

    def bulk_import_resources(
        self,
        *,
        target_type: str,
        rows: list[Mapping[str, Any]],
        dry_run: bool = True,
    ) -> Any:
        """Submit a bulk import; ``dry_run`` defaults to True so nothing is written by accident."""
        return self.request(
            "POST",
            "/api/v1/bulk_imports",
            json={"target_type": target_type, "rows": rows, "dry_run": dry_run},
            trace_type="control",
        )

    def flush_trace(
        self,
        *,
        name: str | None = None,
        runtime: str = "custom",
        status: str = "running",
        metadata: Mapping[str, Any] | None = None,
    ) -> Any:
        """Persist the recorded trace via POST /api/v1/traces and return the response."""
        trace = self.trace.as_dict()
        return self.request(
            "POST",
            "/api/v1/traces",
            json={
                "run_id": trace["run_id"],
                "name": name,
                "runtime": runtime,
                "status": status,
                "metadata": dict(metadata or {}),
                "events": trace["events"],
            },
            trace_type="control",
        )

    def call_agent(
        self,
        *,
        agent_name: str,
        message: str,
        context: Mapping[str, Any] | None = None,
    ) -> Any:
        """Dispatch a message to a tenant agent through /v1/agents/{name}/message."""
        return self.request(
            "POST",
            f"/v1/agents/{agent_name}/message",
            base_url=self._proxy_base,
            json={"message": message, "context": dict(context or {})},
            trace_type="agent",
        )

    async def acall_agent(
        self,
        *,
        agent_name: str,
        message: str,
        context: Mapping[str, Any] | None = None,
    ) -> Any:
        """Async counterpart of :meth:`call_agent`."""
        return await self.arequest(
            "POST",
            f"/v1/agents/{agent_name}/message",
            base_url=self._proxy_base,
            json={"message": message, "context": dict(context or {})},
            trace_type="agent",
        )

    def list_mcp_tools(self) -> Any:
        """Discover MCP tools exposed by the proxy via GET /v1/mcp/tools."""
        return self.request("GET", "/v1/mcp/tools", base_url=self._proxy_base, trace_type="mcp")

    def call_mcp_tool(
        self,
        *,
        server_name: str,
        tool_name: str,
        arguments: Mapping[str, Any] | None = None,
    ) -> Any:
        """Execute an MCP tool through the proxy via POST /v1/mcp/tools/call."""
        return self.request(
            "POST",
            "/v1/mcp/tools/call",
            base_url=self._proxy_base,
            json={
                "server_name": server_name,
                "tool_name": tool_name,
                "arguments": dict(arguments or {}),
            },
            trace_type="mcp",
        )

    async def acall_mcp_tool(
        self,
        *,
        server_name: str,
        tool_name: str,
        arguments: Mapping[str, Any] | None = None,
    ) -> Any:
        """Async counterpart of :meth:`call_mcp_tool`."""
        return await self.arequest(
            "POST",
            "/v1/mcp/tools/call",
            base_url=self._proxy_base,
            json={
                "server_name": server_name,
                "tool_name": tool_name,
                "arguments": dict(arguments or {}),
            },
            trace_type="mcp",
        )

    def new_run_id(self, prefix: str = "run") -> str:
        """Return a fresh unique run id of the form '{prefix}_{hex}'."""
        return f"{prefix}_{uuid.uuid4().hex}"

    @property
    def _proxy_base(self) -> str:
        """Proxy root for non-OpenAI routes (/v1/agents, /v1/mcp): inference_url minus '/v1'."""
        return self.inference_url.removesuffix("/v1")

    @staticmethod
    def _join(base: str, path: str) -> str:
        """Join a base URL and a path with exactly one slash between them."""
        return f"{base.rstrip('/')}/{path.lstrip('/')}"

    @staticmethod
    def _query(params: Mapping[str, Any]) -> str:
        """URL-encode truthy params into a query string; '' when nothing is set.

        Encoding (rather than raw f-string interpolation) keeps values
        containing spaces, '&', or '=' from corrupting the request URL.
        """
        present = {key: value for key, value in params.items() if value}
        return f"?{urlencode(present)}" if present else ""

    def _headers(self) -> dict[str, str]:
        """Default headers for every control-plane and proxy request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
            "User-Agent": self.user_agent,
        }

    @staticmethod
    def _parse_response(response: httpx.Response) -> Any:
        """Decode a response body and raise :class:`TuningError` on HTTP errors.

        Returns the parsed JSON object, the raw text when the body is not
        JSON, or ``{}`` for an empty body.
        """
        text = response.text
        try:
            payload = response.json() if text else {}
        except ValueError:
            payload = text
        if response.status_code >= 400:
            # Prefer the API's "error" field, but keep the whole payload when
            # the field is absent so detail is not lost in the message.
            message = payload.get("error", payload) if isinstance(payload, dict) else payload
            raise TuningError(f"Tuning Engines API error {response.status_code}: {message}")
        return payload