roder-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roder_sdk/__init__.py +24 -0
- roder_sdk/agent.py +255 -0
- roder_sdk/client.py +60 -0
- roder_sdk/errors.py +24 -0
- roder_sdk/events.py +46 -0
- roder_sdk/hosted.py +121 -0
- roder_sdk/run.py +63 -0
- roder_sdk/transports.py +158 -0
- roder_sdk/types_generated.py +2906 -0
- roder_sdk-0.1.0.dist-info/METADATA +36 -0
- roder_sdk-0.1.0.dist-info/RECORD +13 -0
- roder_sdk-0.1.0.dist-info/WHEEL +5 -0
- roder_sdk-0.1.0.dist-info/top_level.txt +1 -0
roder_sdk/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from .agent import RoderAgent
|
|
2
|
+
from .client import RoderRpcClient
|
|
3
|
+
from .events import normalize_notification
|
|
4
|
+
from .hosted import HostedClient
|
|
5
|
+
from .errors import RoderRpcError, RoderTransportError
|
|
6
|
+
from .run import RoderRun
|
|
7
|
+
from .transports import InMemoryTransport, LocalProcessTransport, WebSocketTransport
|
|
8
|
+
from .types_generated import APP_SERVER_MANIFEST, APP_SERVER_METHODS, AppServerMethod
|
|
9
|
+
|
|
10
|
+
# Public API of the roder_sdk package; every name listed is imported above.
__all__ = [
    "APP_SERVER_MANIFEST",
    "APP_SERVER_METHODS",
    "AppServerMethod",
    "HostedClient",
    "InMemoryTransport",
    "LocalProcessTransport",
    "RoderAgent",
    "RoderRpcClient",
    "RoderRpcError",
    "RoderRun",
    "RoderTransportError",
    "WebSocketTransport",
    "normalize_notification",
]
|
roder_sdk/agent.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from .client import RoderRpcClient
|
|
9
|
+
from .events import EventMode
|
|
10
|
+
from .run import RoderRun
|
|
11
|
+
from .transports import InMemoryTransport, LocalProcessTransport, RoderTransport, WebSocketTransport
|
|
12
|
+
|
|
13
|
+
# Signature of the callbacks in RoderAgent's ``approvals`` mapping: called with
# the notification params, returns a decision dict directly or as an awaitable.
ApprovalCallback = Callable[[Any], dict[str, Any] | Awaitable[dict[str, Any]]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RoderAgent:
    """High-level agent bound to one Roder transport.

    Wraps a :class:`RoderRpcClient`, lazily starts (or reuses) a thread, and
    starts turns via :meth:`send`. When ``approvals`` callbacks are supplied,
    a background task answers request notifications from the server.

    Recognized ``approvals`` keys:
        ``on_tool_approval``  -> answers ``thread/approvalRequested``
        ``on_user_input``     -> answers ``thread/userInputRequested``
        ``on_plan_exit``      -> answers ``thread/planExitRequested``
    """

    def __init__(
        self,
        transport: RoderTransport,
        *,
        cwd: str | None = None,
        model: dict[str, str] | None = None,
        thread_id: str | None = None,
        workspace_id: str | None = None,
        tool_allowlist: list[str] | None = None,
        instructions: str | None = None,
        runner: dict[str, Any] | None = None,
        approvals: dict[str, ApprovalCallback] | None = None,
        event_mode: EventMode = "permissive",
    ) -> None:
        self.transport = transport
        self.client = RoderRpcClient(transport)
        self.cwd = cwd
        # Read with the keys "id" and "provider" when a thread is started.
        self.model = model or {}
        self.thread_id = thread_id
        self.workspace_id = workspace_id
        self.tool_allowlist = tool_allowlist
        self.instructions = instructions
        self.runner = runner
        self.approvals = approvals or {}
        self.event_mode: EventMode = event_mode
        self._callback_task: asyncio.Task[None] | None = None

    @classmethod
    async def create(
        cls,
        *,
        local: dict[str, Any] | None = None,
        remote: dict[str, Any] | None = None,
        transport: RoderTransport | None = None,
        cwd: str | None = None,
        model: dict[str, str] | None = None,
        thread_id: str | None = None,
        workspace_id: str | None = None,
        tool_allowlist: list[str] | None = None,
        instructions: str | None = None,
        runner: dict[str, Any] | None = None,
        approvals: dict[str, ApprovalCallback] | None = None,
        event_mode: EventMode = "permissive",
    ) -> "RoderAgent":
        """Resolve a transport and build an agent with its approvals loop running.

        An explicit ``transport`` wins over ``remote`` (WebSocket options),
        which wins over ``local`` (spawned-process options).
        """
        resolved = await _resolve_transport(local=local, remote=remote, transport=transport, cwd=cwd)
        agent = cls(
            resolved,
            cwd=cwd,
            model=model,
            thread_id=thread_id,
            workspace_id=workspace_id,
            tool_allowlist=tool_allowlist,
            instructions=instructions,
            runner=runner,
            approvals=approvals,
            event_mode=event_mode,
        )
        agent._start_callback_loop()
        return agent

    async def __aenter__(self) -> "RoderAgent":
        # Agents built directly via __init__ have no approvals loop yet; start
        # it here so ``async with RoderAgent(...)`` honours ``approvals``.
        # This is a no-op when create() already started it.
        self._start_callback_loop()
        return self

    async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
        await self.close()

    async def send(
        self,
        input: str | list[dict[str, Any]],
        *,
        developer_context: str | None = None,
    ) -> RoderRun:
        """Start a turn on this agent's thread and return a handle for it.

        A thread is started first if the agent does not have one yet; its id
        is remembered for later turns. ``input`` is plain text or a list of
        pre-built input items.

        developer_context is per-turn developer-authority context layered after
        the thread's developerInstructions for this turn only. Never persisted
        with the thread; resend it on each turn that needs it.

        Raises:
            RuntimeError: the ``turn/start`` response carried no turn id.
        """
        thread_id = self.thread_id or await self._start_thread()
        self.thread_id = thread_id
        params: dict[str, Any] = {"threadId": thread_id, "input": _normalize_input(input)}
        if developer_context is not None:
            params["developerContext"] = developer_context
        result = await self.client.call("turn/start", params)
        # Accept {"turn": {"id": ...}}, {"turnId": ...} or {"id": ...}.
        turn_id = _extract_id(result, "turn") or _extract_string(result, "turnId") or _extract_string(result, "id")
        if not turn_id:
            raise RuntimeError("turn/start response did not include a turn id")
        return RoderRun(self.client, thread_id, turn_id, event_mode=self.event_mode)

    async def list_models(self) -> Any:
        """Return the raw ``model/list`` result."""
        return await self.client.call("model/list")

    async def list_providers(self) -> Any:
        """Return the raw ``providers/list`` result."""
        return await self.client.call("providers/list")

    async def read_thread(self, thread_id: str | None = None) -> Any:
        """Return ``thread/read`` for ``thread_id`` or the agent's current thread.

        Raises:
            RuntimeError: neither an argument nor a current thread id exists.
        """
        selected = thread_id or self.thread_id
        if not selected:
            raise RuntimeError("read_thread requires a thread id")
        return await self.client.call("thread/read", {"threadId": selected})

    async def list_threads(self) -> Any:
        """Return the raw ``thread/list`` result."""
        return await self.client.call("thread/list")

    async def list_tools(self) -> Any:
        """Return the raw ``tools/list`` result."""
        return await self.client.call("tools/list")

    async def list_commands(self) -> Any:
        """Return the raw ``commands/list`` result."""
        return await self.client.call("commands/list")

    async def close(self) -> None:
        """Stop the approvals task (if any), then close the RPC client."""
        task, self._callback_task = self._callback_task, None
        try:
            if task is not None:
                task.cancel()
                # Let the task actually unwind instead of leaving it pending
                # after the transport is gone. return_exceptions=True collects
                # its CancelledError (or an earlier failure) rather than
                # raising it out of close().
                await asyncio.gather(task, return_exceptions=True)
        finally:
            await self.client.close()

    async def _start_thread(self) -> str:
        """Create a thread via ``thread/start`` from the agent's configuration.

        ``cwd``, ``model`` and ``modelProvider`` are always sent (possibly as
        null); allowlist, instructions and runner only when configured.

        Raises:
            RuntimeError: the response carried no thread id.
        """
        workspace_id = self.workspace_id or await self._resolve_workspace_id(self.cwd)
        params: dict[str, Any] = {
            "cwd": self.cwd,
            "model": self.model.get("id"),
            "modelProvider": self.model.get("provider"),
        }
        if self.tool_allowlist is not None:
            params["toolAllowlist"] = self.tool_allowlist
        if self.instructions is not None:
            params["developerInstructions"] = self.instructions
        if self.runner is not None:
            params["runner"] = self.runner
        params["workspaceId"] = workspace_id
        result = await self.client.call("thread/start", params)
        thread_id = _extract_id(result, "thread") or _extract_string(result, "threadId") or _extract_string(result, "id")
        if not thread_id:
            raise RuntimeError("thread/start response did not include a thread id")
        return thread_id

    async def _resolve_workspace_id(self, cwd: str | None) -> str:
        """Find a workspace with a root whose ``path`` equals ``cwd``, else create one.

        NOTE(review): the comparison is an exact string match; paths that
        differ only by normalization (trailing slash, symlinks) won't match.

        Raises:
            RuntimeError: no ``cwd`` was given, or ``workspace/create`` returned
                no workspace id.
        """
        if not cwd:
            raise RuntimeError("starting a thread requires a workspace_id or a cwd to resolve one from")
        listed = await self.client.call("workspace/list", {})
        workspaces = listed.get("workspaces") if isinstance(listed, dict) else None
        for workspace in workspaces if isinstance(workspaces, list) else []:
            if not isinstance(workspace, dict):
                continue
            roots = workspace.get("roots")
            workspace_id = _extract_string(workspace, "id")
            if (
                workspace_id
                and isinstance(roots, list)
                and any(_extract_string(root, "path") == cwd for root in roots)
            ):
                return workspace_id
        created = await self.client.call("workspace/create", {"roots": [{"path": cwd}]})
        workspace_id = _extract_id(created, "workspace")
        if not workspace_id:
            raise RuntimeError("workspace/create response did not include a workspace id")
        return workspace_id

    def _start_callback_loop(self) -> None:
        """Spawn the approvals task once, and only if callbacks were supplied."""
        if not self.approvals or self._callback_task is not None:
            return
        self._callback_task = asyncio.create_task(self._callback_loop())

    async def _callback_loop(self) -> None:
        """Dispatch every notification to the approvals handler."""
        async for notification in self.client.notifications():
            await self._handle_callback_notification(str(notification.get("method")), notification.get("params"))

    async def _decision(self, callback_name: str, params: Any) -> dict[str, Any]:
        """Run the named approvals callback and return its decision dict.

        A result that is not a dict is treated as an empty decision, so the
        request is resolved as not approved / without answers instead of
        crashing the approvals loop with AttributeError.
        """
        decision = await _maybe_await(self.approvals[callback_name](params))
        return decision if isinstance(decision, dict) else {}

    async def _handle_callback_notification(self, method: str, params: Any) -> None:
        """Answer a server request notification with the matching callback's decision."""
        if method == "thread/approvalRequested" and "on_tool_approval" in self.approvals:
            decision = await self._decision("on_tool_approval", params)
            await self.client.call(
                "thread/resolve_approval",
                {
                    "approvalId": _extract_string(params, "approvalId"),
                    "approved": bool(decision.get("approved")),
                },
            )
        elif method == "thread/userInputRequested" and "on_user_input" in self.approvals:
            decision = await self._decision("on_user_input", params)
            await self.client.call(
                "thread/resolve_user_input",
                {"requestId": _extract_string(params, "requestId"), "answers": decision.get("answers")},
            )
        elif method == "thread/planExitRequested" and "on_plan_exit" in self.approvals:
            decision = await self._decision("on_plan_exit", params)
            await self.client.call(
                "thread/exit_plan",
                {
                    "requestId": _extract_string(params, "requestId"),
                    "approved": bool(decision.get("approved")),
                },
            )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
async def _resolve_transport(
    *,
    local: dict[str, Any] | None,
    remote: dict[str, Any] | None,
    transport: RoderTransport | None,
    cwd: str | None,
) -> RoderTransport:
    """Choose the transport for ``RoderAgent.create``.

    Precedence: an explicit ``transport``; then ``remote`` (keyword arguments
    for ``WebSocketTransport.connect``); then ``local`` (options for spawning
    a ``LocalProcessTransport``, defaulting to the ``roder`` command and the
    agent's ``cwd``).

    ``local`` is tested with ``is not None`` so that ``local={}`` selects a
    local process with all defaults rather than silently falling through.
    ``remote`` keeps truthiness semantics: an empty mapping carries no
    connection details and is ignored.

    With nothing configured, returns an in-memory transport that answers
    every request with a JSON-RPC error.
    """
    if transport is not None:
        return transport
    if remote:
        return await WebSocketTransport.connect(**remote)
    if local is not None:
        return await LocalProcessTransport.create(
            command=local.get("command", "roder"),
            args=local.get("args"),
            cwd=local.get("cwd", cwd),
            env=local.get("env"),
        )
    return InMemoryTransport(
        lambda request: {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {"code": -32000, "message": "no transport configured"},
        }
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _normalize_input(input: str | list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
237
|
+
return [{"type": "text", "text": input}] if isinstance(input, str) else input
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _extract_id(value: Any, key: str) -> str | None:
|
|
241
|
+
if isinstance(value, dict) and isinstance(value.get(key), dict):
|
|
242
|
+
return _extract_string(value[key], "id")
|
|
243
|
+
return None
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _extract_string(value: Any, key: str) -> str | None:
|
|
247
|
+
if isinstance(value, dict) and isinstance(value.get(key), str):
|
|
248
|
+
return str(value[key])
|
|
249
|
+
return None
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
async def _maybe_await(value: dict[str, Any] | Awaitable[dict[str, Any]]) -> dict[str, Any]:
|
|
253
|
+
if inspect.isawaitable(value):
|
|
254
|
+
return await cast(Awaitable[dict[str, Any]], value)
|
|
255
|
+
return cast(dict[str, Any], value)
|
roder_sdk/client.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .errors import RoderRpcError
|
|
7
|
+
from .transports import JsonRpcNotification, RoderTransport
|
|
8
|
+
from .types_generated import APP_SERVER_METHODS, AppServerMethod, JsonRpcRequest, JsonRpcResponse
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RoderRpcClient:
    """JSON-RPC 2.0 client for the Roder app server over a RoderTransport.

    Request ids are sequential integers starting at 1. ``methods`` maps each
    name in APP_SERVER_METHODS to an async helper that takes optional params.
    """

    def __init__(self, transport: RoderTransport) -> None:
        self._transport = transport
        self._next_id = 1
        self.methods: dict[str, Callable[[Any | None], Any]] = {
            method: self._method_helper(method) for method in APP_SERVER_METHODS
        }

    async def call(self, method: AppServerMethod, params: Any = None) -> Any:
        """Invoke ``method`` and return the response's ``result`` member.

        ``params`` is omitted from the request when it is None.

        Raises:
            RoderRpcError: the response carries an ``error`` member. Malformed
                error members (not an object, missing or non-integer code)
                are tolerated and reported with code -32000.
        """
        request_id = self._allocate_id()
        request: JsonRpcRequest = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
        }
        if params is not None:
            request["params"] = params
        response = await self.raw_request(request)
        error = response.get("error")
        # Any present error member signals failure, even a falsy one like {}.
        if error is None:
            return response.get("result")
        details: dict[str, Any] = error if isinstance(error, dict) else {"message": str(error)}
        try:
            code = int(details.get("code", -32000))
        except (TypeError, ValueError):
            code = -32000
        response_id = response.get("id")
        raise RoderRpcError(
            code=code,
            message=str(details.get("message", "JSON-RPC error")),
            data=details.get("data"),
            method=method,
            # JSON-RPC allows a null id when the server could not read the
            # request; fall back to the id we sent so callers can correlate.
            request_id=response_id if response_id is not None else request_id,
        )

    async def raw_request(self, request: JsonRpcRequest) -> JsonRpcResponse:
        """Send a pre-built request and return the unprocessed response."""
        return await self._transport.request(request)

    def notifications(self):
        """Return the transport's notification stream as-is."""
        return self._transport.notifications()

    async def close(self) -> None:
        """Close the underlying transport."""
        await self._transport.close()

    def _allocate_id(self) -> int:
        """Return the next request id and advance the counter."""
        request_id = self._next_id
        self._next_id += 1
        return request_id

    def _method_helper(self, method: str) -> Callable[[Any | None], Any]:
        """Build an async ``helper(params=None)`` that calls ``method``."""
        async def helper(params: Any = None) -> Any:
            return await self.call(method, params)  # type: ignore[arg-type]

        return helper
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Public names of this module; JsonRpcNotification is re-exported from .transports.
__all__ = ["JsonRpcNotification", "RoderRpcClient"]
|
roder_sdk/errors.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RoderRpcError(Exception):
|
|
7
|
+
def __init__(
|
|
8
|
+
self,
|
|
9
|
+
*,
|
|
10
|
+
code: int,
|
|
11
|
+
message: str,
|
|
12
|
+
data: Any = None,
|
|
13
|
+
method: str,
|
|
14
|
+
request_id: str | int | None,
|
|
15
|
+
) -> None:
|
|
16
|
+
super().__init__(message)
|
|
17
|
+
self.code = code
|
|
18
|
+
self.data = data
|
|
19
|
+
self.method = method
|
|
20
|
+
self.request_id = request_id
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RoderTransportError(Exception):
    """Failure of the transport layer itself (e.g. a dropped connection), distinct from RoderRpcError."""
|
roder_sdk/events.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, TypedDict
|
|
4
|
+
|
|
5
|
+
# A JSON-RPC notification as received from the transport (a plain dict).
JsonRpcNotification = dict[str, Any]
# "strict" drops notifications with no known SDK event type;
# "permissive" surfaces them as "raw.notification" events.
EventMode = Literal["strict", "permissive"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class RoderSdkEvent(TypedDict):
    """A normalized SDK event as produced by normalize_notification."""

    # Dotted SDK event name (e.g. "turn.completed") or "raw.notification".
    type: str
    # The original JSON-RPC notification the event was derived from.
    raw: JsonRpcNotification
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Maps app-server notification method names to SDK event types. All streaming
# text/reasoning delta methods collapse into the single "item.delta" type.
EVENT_TYPES: dict[str, str] = {
    "thread/started": "thread.started",
    "thread/status/changed": "thread.status.changed",
    "turn/started": "turn.started",
    "turn/completed": "turn.completed",
    "item/started": "item.started",
    "item/completed": "item.completed",
    "item/agentMessage/delta": "item.delta",
    "item/reasoning/textDelta": "item.delta",
    "item/reasoning/summaryPartAdded": "item.delta",
    "item/reasoning/summaryTextDelta": "item.delta",
    "thread/toolExecutionRequested": "tool_execution.requested",
    "thread/toolExecutionResolved": "tool_execution.resolved",
    "thread/approvalRequested": "approval.requested",
    "thread/approvalResolved": "approval.resolved",
    "thread/userInputRequested": "user_input.requested",
    "thread/userInputResolved": "user_input.resolved",
    "thread/planExitRequested": "plan_exit.requested",
    "thread/planExitResolved": "plan_exit.resolved",
    "command/exec/outputDelta": "command.output_delta",
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def normalize_notification(
    raw: JsonRpcNotification,
    mode: EventMode = "permissive",
) -> RoderSdkEvent | None:
    """Turn a raw notification into an SDK event.

    Methods listed in EVENT_TYPES become their mapped event type. Any other
    method yields a "raw.notification" event in permissive mode and None
    (dropped) otherwise.
    """
    mapped = EVENT_TYPES.get(str(raw.get("method")))
    if not mapped:
        return {"type": "raw.notification", "raw": raw} if mode == "permissive" else None
    return {"type": mapped, "raw": raw}
|
roder_sdk/hosted.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Hosted multi-tenant connection helpers (roadmap phase 72, Task 6).
|
|
2
|
+
|
|
3
|
+
Hosted Roder authenticates at the WebSocket handshake with a bearer
|
|
4
|
+
credential in the ``Authorization`` header — the gateway always rejects
|
|
5
|
+
query-string credentials. :class:`HostedClient` wraps the standard RPC
|
|
6
|
+
client with typed helpers for ``hosted/*`` methods; raw JSON-RPC access
|
|
7
|
+
stays available via ``client.raw_request`` for forward-compatible hosted
|
|
8
|
+
methods.
|
|
9
|
+
|
|
10
|
+
Token refresh/reconnect: connections authenticate once at handshake time,
|
|
11
|
+
so refreshing a token means reconnecting. :meth:`HostedClient.reconnect`
|
|
12
|
+
builds a fresh transport using the token provider. Requests in flight when
|
|
13
|
+
a connection drops fail with a transport error and are NEVER replayed
|
|
14
|
+
automatically — callers retry mutating requests themselves because only
|
|
15
|
+
they know whether the operation is idempotent.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import inspect
|
|
21
|
+
from collections.abc import Awaitable, Callable
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from .client import RoderRpcClient
|
|
25
|
+
from .transports import WebSocketTransport
|
|
26
|
+
|
|
27
|
+
# Zero-argument callable returning a bearer token directly or as an awaitable;
# when configured it is called on every connect and reconnect.
TokenProvider = Callable[[], "str | Awaitable[str]"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class HostedClient:
    """RPC client for a hosted, multi-tenant Roder gateway.

    Wraps a :class:`RoderRpcClient` (exposed as ``client``) with helpers for
    ``hosted/*`` methods. Usable as ``async with await HostedClient.connect(...)``
    so the connection is closed automatically.
    """

    def __init__(self, options: dict[str, Any], client: RoderRpcClient) -> None:
        # Connection options are retained so reconnect() can rebuild the transport.
        self._options = options
        self.client = client

    @classmethod
    async def connect(
        cls,
        url: str,
        *,
        token: str | None = None,
        token_provider: TokenProvider | None = None,
        headers: dict[str, str] | None = None,
        connector: Callable[..., Awaitable[Any]] | None = None,
    ) -> "HostedClient":
        """Connects and authenticates against a hosted Roder gateway."""
        options = {
            "url": url,
            "token": token,
            "token_provider": token_provider,
            "headers": headers,
            "connector": connector,
        }
        client = await _hosted_rpc_client(options)
        return cls(options, client)

    async def __aenter__(self) -> "HostedClient":
        return self

    async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
        await self.close()

    async def reconnect(self) -> None:
        """Re-authenticates with a fresh credential and replaces the
        connection. In-flight requests on the old connection fail; nothing
        is replayed."""
        # Build the new client first so a failed reconnect leaves the
        # existing connection in place.
        next_client = await _hosted_rpc_client(self._options)
        previous = self.client
        self.client = next_client
        await previous.close()

    async def whoami(self) -> dict[str, Any]:
        """Call ``hosted/whoami``."""
        return await self.client.call("hosted/whoami", {})

    async def create_service_account(self, display_name: str) -> dict[str, Any]:
        """Call ``hosted/service_accounts/create`` with the given display name."""
        return await self.client.call(
            "hosted/service_accounts/create", {"displayName": display_name}
        )

    async def revoke_service_account(self, key_id: str) -> dict[str, Any]:
        """Call ``hosted/service_accounts/revoke`` for ``key_id``."""
        return await self.client.call("hosted/service_accounts/revoke", {"keyId": key_id})

    async def list_hooks(self) -> dict[str, Any]:
        """Call ``hosted/hooks/list``."""
        return await self.client.call("hosted/hooks/list", {})

    async def create_hook(self, hook: dict[str, Any]) -> dict[str, Any]:
        """Call ``hosted/hooks/create`` with ``hook`` as the hook definition."""
        return await self.client.call("hosted/hooks/create", {"hook": hook})

    async def delete_hook(self, hook_id: str) -> dict[str, Any]:
        """Call ``hosted/hooks/delete`` for ``hook_id``."""
        return await self.client.call("hosted/hooks/delete", {"hookId": hook_id})

    async def audit_list(self) -> dict[str, Any]:
        """Call ``hosted/audit/list``."""
        return await self.client.call("hosted/audit/list", {})

    def notifications(self):
        """Return the current connection's notification stream."""
        return self.client.notifications()

    async def close(self) -> None:
        """Close the current connection."""
        await self.client.close()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
async def _resolve_token(options: dict[str, Any]) -> str | None:
|
|
96
|
+
provider: TokenProvider | None = options.get("token_provider")
|
|
97
|
+
if provider is not None:
|
|
98
|
+
token = provider()
|
|
99
|
+
if inspect.isawaitable(token):
|
|
100
|
+
token = await token
|
|
101
|
+
return str(token)
|
|
102
|
+
return options.get("token")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
async def _hosted_rpc_client(options: dict[str, Any]) -> RoderRpcClient:
    """Open an authenticated WebSocket connection and wrap it in a RoderRpcClient.

    Raises:
        ValueError: no token could be resolved and the caller-supplied headers
            contain no Authorization header (matched case-insensitively).
    """
    token = await _resolve_token(options)
    headers: dict[str, str] | None = options.get("headers")
    has_auth_header = False
    for name in headers or ():
        if name.lower() == "authorization":
            has_auth_header = True
            break
    if not (token or has_auth_header):
        raise ValueError(
            "hosted connections require a token, token_provider, or an Authorization header"
        )
    transport = await WebSocketTransport.connect(
        options["url"],
        token=token,
        headers=headers,
        connector=options.get("connector"),
    )
    return RoderRpcClient(transport)
|
roder_sdk/run.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import anyio
|
|
7
|
+
|
|
8
|
+
from .client import RoderRpcClient
|
|
9
|
+
from .events import EventMode, JsonRpcNotification, RoderSdkEvent, normalize_notification
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RoderRun:
    """Handle for a single turn, as returned by ``RoderAgent.send``."""

    def __init__(
        self,
        client: RoderRpcClient,
        thread_id: str,
        turn_id: str,
        *,
        event_mode: EventMode = "permissive",
    ) -> None:
        self.client = client
        self.thread_id = thread_id
        self.turn_id = turn_id
        self.event_mode: EventMode = event_mode
        # Cancelled by cancel(). NOTE(review): nothing in this class enters
        # the scope; it only has an effect if a caller enters it themselves
        # (``with run.cancel_scope: ...``).
        self.cancel_scope = anyio.CancelScope()

    async def stream(self) -> AsyncIterator[RoderSdkEvent]:
        """Yield normalized events until this run's turn completes.

        Notifications are not filtered by thread or turn, so events from other
        turns on the same client are yielded too. The stream ends right after
        yielding a ``turn.completed`` event that matches ``turn_id`` (or that
        carries no turn id at all).
        """
        async for notification in self.client.notifications():
            event = normalize_notification(notification, self.event_mode)
            if event is not None:
                yield event
            if event and event["type"] == "turn.completed" and _matches_turn(notification.get("params"), self.turn_id):
                return

    def raw_events(self) -> AsyncIterator[JsonRpcNotification]:
        """Return the client's unfiltered, unnormalized notification stream."""
        return self.client.notifications()

    async def wait(self) -> RoderSdkEvent | None:
        """Consume the stream and return this turn's ``turn.completed`` event.

        ``turn.completed`` events for other turns are skipped rather than
        returned. Returns None if the notification stream ends before this
        turn completes.
        """
        completed: RoderSdkEvent | None = None
        # Not returning from inside the loop: stream() stops on its own right
        # after the matching completion, so the generator is closed normally
        # instead of being left suspended.
        async for event in self.stream():
            if event["type"] == "turn.completed" and _matches_turn(event["raw"].get("params"), self.turn_id):
                completed = event
        return completed

    async def cancel(self, reason: str = "sdk cancel") -> Any:
        """Cancel ``cancel_scope`` and ask the server to interrupt this turn."""
        self.cancel_scope.cancel()
        return await self.client.call(
            "turn/interrupt",
            {"threadId": self.thread_id, "turnId": self.turn_id, "reason": reason},
        )

    async def result(self) -> Any:
        """Return ``thread/read`` for this run's thread."""
        return await self.client.call("thread/read", {"threadId": self.thread_id})
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _matches_turn(params: Any, turn_id: str) -> bool:
    """Decide whether notification params refer to ``turn_id``.

    Params with no recognizable turn id (not a dict, or neither a ``turnId``
    field nor a ``turn`` dict with an ``id``) are treated as matching.
    """
    if not isinstance(params, dict):
        return True
    direct_id = params.get("turnId")
    nested = params.get("turn")
    nested_id = nested.get("id") if isinstance(nested, dict) else None
    if nested_id is None and direct_id is None:
        return True
    return direct_id == turn_id or nested_id == turn_id
|