superdialog 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. superdialog/__init__.py +55 -0
  2. superdialog/adapters/__init__.py +13 -0
  3. superdialog/adapters/fastapi.py +120 -0
  4. superdialog/adapters/livekit.py +197 -0
  5. superdialog/adapters/pipecat.py +84 -0
  6. superdialog/adapters/websocket.py +200 -0
  7. superdialog/agent.py +55 -0
  8. superdialog/agents/__init__.py +5 -0
  9. superdialog/agents/langchain_agent.py +90 -0
  10. superdialog/agents/llm_agent.py +101 -0
  11. superdialog/chat_context.py +33 -0
  12. superdialog/cli/__init__.py +9 -0
  13. superdialog/cli/main.py +256 -0
  14. superdialog/dialog_machine.py +410 -0
  15. superdialog/flow/__init__.py +17 -0
  16. superdialog/flow/_loaders.py +100 -0
  17. superdialog/flow/bootstrap.py +456 -0
  18. superdialog/flow/enums.py +14 -0
  19. superdialog/flow/loader.py +39 -0
  20. superdialog/flow/models.py +545 -0
  21. superdialog/flow_state.py +55 -0
  22. superdialog/llm/__init__.py +17 -0
  23. superdialog/llm/litellm_provider.py +78 -0
  24. superdialog/llm/provider.py +36 -0
  25. superdialog/llm/registry.py +27 -0
  26. superdialog/llm/resolver.py +36 -0
  27. superdialog/machine/__init__.py +41 -0
  28. superdialog/machine/_lang_util.py +49 -0
  29. superdialog/machine/_prompts.py +115 -0
  30. superdialog/machine/actions.py +139 -0
  31. superdialog/machine/adapters/__init__.py +6 -0
  32. superdialog/machine/adapters/base.py +83 -0
  33. superdialog/machine/adapters/llm_adapter.py +311 -0
  34. superdialog/machine/adapters/text_adapter.py +120 -0
  35. superdialog/machine/adapters/toolcall_adapter.py +541 -0
  36. superdialog/machine/composer.py +505 -0
  37. superdialog/machine/criteria.py +406 -0
  38. superdialog/machine/extractor.py +138 -0
  39. superdialog/machine/gate.py +336 -0
  40. superdialog/machine/hooks.py +86 -0
  41. superdialog/machine/machine.py +2162 -0
  42. superdialog/machine/models.py +531 -0
  43. superdialog/machine/runner.py +343 -0
  44. superdialog/machine/store.py +34 -0
  45. superdialog/machine/testing/__init__.py +16 -0
  46. superdialog/machine/testing/flow_smoke.py +161 -0
  47. superdialog/machine/testing/mock_adapter.py +136 -0
  48. superdialog/machine/testing/sample_appointment_flow.json +107 -0
  49. superdialog/machine/testing/test_self_loop_protection.py +509 -0
  50. superdialog/machine/tools/__init__.py +27 -0
  51. superdialog/machine/tools/base.py +47 -0
  52. superdialog/machine/tools/builtins.py +135 -0
  53. superdialog/machine/tools/registry.py +60 -0
  54. superdialog/machine/transitions.py +94 -0
  55. superdialog/py.typed +0 -0
  56. superdialog/session/__init__.py +19 -0
  57. superdialog/session/lock.py +52 -0
  58. superdialog/session/record.py +29 -0
  59. superdialog/session/session.py +82 -0
  60. superdialog/session/store.py +61 -0
  61. superdialog/session/worker.py +137 -0
  62. superdialog/stream.py +44 -0
  63. superdialog/tools/__init__.py +9 -0
  64. superdialog/tools/base.py +78 -0
  65. superdialog/tools/decorator.py +51 -0
  66. superdialog/tools/http_tool.py +37 -0
  67. superdialog/tools/mcp_tool.py +53 -0
  68. superdialog/tools/python_tool.py +67 -0
  69. superdialog/traversal/.gitkeep +0 -0
  70. superdialog/traversal/__init__.py +5 -0
  71. superdialog/traversal/history/.gitignore +4 -0
  72. superdialog/traversal/history/.gitkeep +0 -0
  73. superdialog/traversal/traversal.py +195 -0
  74. superdialog-0.2.0a0.dist-info/METADATA +40 -0
  75. superdialog-0.2.0a0.dist-info/RECORD +78 -0
  76. superdialog-0.2.0a0.dist-info/WHEEL +4 -0
  77. superdialog-0.2.0a0.dist-info/entry_points.txt +2 -0
  78. superdialog-0.2.0a0.dist-info/licenses/LICENSE +202 -0
@@ -0,0 +1,55 @@
1
+ """SuperDialog -- standalone dialog state machine framework."""
2
+
3
+ __version__ = "0.2.0a0"
4
+
5
+ from .agent import Agent, TurnResult
6
+ from .agents import LLMAgent
7
+ from .chat_context import ChatContext, ChatMessage
8
+ from .dialog_machine import DialogMachine
9
+ from .flow import Flow, FlowSet, create_dialog_flow
10
+ from .flow_state import FlowState
11
+ from .llm.registry import register_llm_provider
12
+ from .session import (
13
+ AsyncioLockBackend,
14
+ InMemorySessionStore,
15
+ LockBackend,
16
+ NullSessionStore,
17
+ Session,
18
+ SessionHandle,
19
+ SessionRecord,
20
+ SessionStore,
21
+ SessionWorker,
22
+ )
23
+ from .stream import StreamChunk, ToolCall, Turn
24
+ from .tools import HttpTool, MCPTool, PythonTool, Tool, ToolResult
25
+
26
# Public API of the ``superdialog`` package: this list controls what
# ``from superdialog import *`` exports. Every entry is one of the names
# imported above, so keep the two in sync when adding or removing exports.
__all__ = [
    "Agent",
    "AsyncioLockBackend",
    "ChatContext",
    "ChatMessage",
    "DialogMachine",
    "Flow",
    "FlowSet",
    "FlowState",
    "HttpTool",
    "InMemorySessionStore",
    "LLMAgent",
    "LockBackend",
    "MCPTool",
    "NullSessionStore",
    "PythonTool",
    "Session",
    "SessionHandle",
    "SessionRecord",
    "SessionStore",
    "SessionWorker",
    "StreamChunk",
    "Tool",
    "ToolCall",
    "ToolResult",
    "Turn",
    "TurnResult",
    "create_dialog_flow",
    "register_llm_provider",
]
@@ -0,0 +1,13 @@
1
+ """Host adapters for :class:`superdialog.DialogMachine`.
2
+
3
+ Each adapter is gated behind its own optional extra so the core package
4
+ stays import-light:
5
+
6
+ * :mod:`superdialog.adapters.livekit` -- ``pip install superdialog[livekit]``
7
+ * :mod:`superdialog.adapters.pipecat` -- ``pip install superdialog[pipecat]``
8
+ * :mod:`superdialog.adapters.fastapi` -- ``pip install superdialog[fastapi]``
9
+ * :mod:`superdialog.adapters.websocket` -- ``pip install superdialog[ws]``
10
+
11
+ Import the submodule you need; this package intentionally does not eagerly
12
+ import any adapter to keep the dependency surface minimal.
13
+ """
@@ -0,0 +1,120 @@
1
+ """FastAPI adapter for any superdialog :class:`Agent`.
2
+
3
+ Provides :func:`make_router` (an ``APIRouter`` factory) and a
4
+ :class:`FastAPIRouter` convenience wrapper that knows how to mount the
5
+ router onto an existing :class:`fastapi.FastAPI` app.
6
+
7
+ Endpoints:
8
+
9
+ * ``POST /turn`` -- run one synchronous turn and return ``{reply, metadata}``.
10
+ * ``POST /stream`` -- run one streaming turn as Server-Sent Events.
11
+ * ``POST /assist`` -- queue a system-level steering instruction.
12
+ * ``POST /reset`` -- drop conversation memory (Agent.reset when available).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ from typing import TYPE_CHECKING, Any
20
+
21
+ from superdialog.stream import StreamChunk
22
+
23
+ if TYPE_CHECKING: # pragma: no cover
24
+ from superdialog.agent import Agent
25
+ else:
26
+ Agent = Any
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def _require_fastapi() -> tuple[Any, Any, Any]:
32
+ try:
33
+ from fastapi import APIRouter, FastAPI # type: ignore
34
+ from fastapi.responses import StreamingResponse # type: ignore
35
+ except ImportError as e:
36
+ raise RuntimeError(
37
+ "FastAPI adapter requires the fastapi extra: "
38
+ "`pip install superdialog[fastapi]`"
39
+ ) from e
40
+ return APIRouter, FastAPI, StreamingResponse
41
+
42
+
43
def make_router(agent: Agent) -> Any:
    """Return an ``APIRouter`` exposing the standard agent endpoints.

    Accepts any superdialog :class:`Agent`. All four routes are always
    registered:

    * ``POST /turn`` -- runs one turn and returns
      ``{"reply": turn.text, "metadata": turn.metadata}``.
    * ``POST /stream`` -- runs one streaming turn as Server-Sent Events.
    * ``POST /assist`` -- calls ``agent.assist(text)`` and answers
      ``{"status": "ok"}``, or ``{"status": "unsupported"}`` when the agent
      has no ``assist`` attribute.
    * ``POST /reset`` -- calls ``agent.reset()`` and answers
      ``{"status": "ok"}``, or ``{"status": "unsupported"}`` when the agent
      has no ``reset`` attribute.

    ``assist`` / ``reset`` may be plain or ``async`` methods; awaitable
    results are awaited. The payload's ``context`` is passed to
    ``agent.turn`` when ``turn`` accepts a ``context`` keyword; otherwise
    ``turn`` is called without it.

    Raises:
        RuntimeError: if the ``fastapi`` extra is not installed.
    """
    import inspect

    APIRouter, _FastAPI, StreamingResponse = _require_fastapi()
    router = APIRouter()

    def _begin_turn(text: str, context: Any, **kwargs: Any) -> Any:
        # Build (but do not await) the turn coroutine. Only the *call* is
        # guarded: an ``async def turn`` without a ``context`` parameter
        # raises TypeError at call time, before any work starts. A
        # TypeError raised while the turn runs must propagate -- retrying
        # at that point would execute the turn a second time.
        try:
            return agent.turn(text, context=context, **kwargs)  # type: ignore[call-arg]
        except TypeError:
            return agent.turn(text, **kwargs)

    async def _call_hook(fn: Any, *args: Any) -> None:
        # Invoke a sync or async agent hook, awaiting the result if needed
        # so async ``assist``/``reset`` implementations actually run.
        result = fn(*args)
        if inspect.isawaitable(result):
            await result

    @router.post("/turn")
    async def handle_turn(payload: dict[str, Any]) -> dict[str, Any]:
        text = payload.get("text", "")
        context = payload.get("context") or None
        turn = await _begin_turn(text, context)
        return {"reply": turn.text, "metadata": turn.metadata}

    @router.post("/stream")
    async def handle_stream(payload: dict[str, Any]) -> Any:
        text = payload.get("text", "")
        context = payload.get("context") or None
        stream = await _begin_turn(text, context, stream=True)

        async def sse() -> Any:
            async for chunk in stream:  # type: ignore[union-attr]
                yield _sse_event(chunk)

        return StreamingResponse(sse(), media_type="text/event-stream")

    @router.post("/assist")
    async def handle_assist(payload: dict[str, Any]) -> dict[str, str]:
        text = payload.get("text", "")
        assist_fn = getattr(agent, "assist", None)
        if assist_fn is None:
            return {"status": "unsupported"}
        await _call_hook(assist_fn, text)
        return {"status": "ok"}

    @router.post("/reset")
    async def handle_reset() -> dict[str, str]:
        reset_fn = getattr(agent, "reset", None)
        if reset_fn is None:
            return {"status": "unsupported"}
        await _call_hook(reset_fn)
        return {"status": "ok"}

    return router
99
+
100
+
101
+ def _sse_event(chunk: StreamChunk) -> str:
102
+ payload = {"text": chunk.text, "done": chunk.done}
103
+ if chunk.done and chunk.turn is not None:
104
+ payload["metadata"] = chunk.turn.metadata
105
+ return f"data: {json.dumps(payload)}\n\n"
106
+
107
+
108
class FastAPIRouter:
    """Bundle a superdialog :class:`Agent` with a ready-made ``APIRouter``.

    The router is built eagerly in ``__init__`` via :func:`make_router`, so
    constructing this class requires the ``fastapi`` extra.
    """

    def __init__(self, agent: Agent) -> None:
        self.agent = agent
        self.router = make_router(agent)

    def mount(self, app: Any, prefix: str = "") -> None:
        """Register the prepared router on ``app`` under URL ``prefix``.

        ``app`` is expected to be a ``FastAPI`` instance (anything with an
        ``include_router(router, prefix=...)`` method works).
        """
        built_router = self.router
        app.include_router(built_router, prefix=prefix)
118
+
119
+
120
# Public names exported by ``from superdialog.adapters.fastapi import *``.
__all__ = ["FastAPIRouter", "make_router"]
@@ -0,0 +1,197 @@
1
+ """LiveKit Agents adapter for any superdialog :class:`Agent`.
2
+
3
+ Mirrors the ``livekit-plugins-langchain`` pattern: expose a class that
4
+ quacks like ``livekit.agents.llm.LLM`` so a LiveKit ``Agent`` can use any
5
+ superdialog Agent (``DialogMachine``, ``LLMAgent``, ``LangChainAgent``)
6
+ as its turn engine.
7
+
8
+ PORT NOTE: confidence-driven barge-in interop (VAD end-of-speech signals
9
+ piped into a richer streaming protocol) is a v0.4 follow-up. The adapter
10
+ currently consumes ``Agent.turn(text, stream=True)`` and surfaces tokens
11
+ as LiveKit ``ChatChunk`` frames.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from typing import TYPE_CHECKING, Any, AsyncIterator
18
+
19
+ from superdialog.stream import StreamChunk
20
+
21
+ if TYPE_CHECKING: # pragma: no cover - only for static type checkers
22
+ from superdialog.agent import Agent
23
+ else:
24
+ Agent = Any # runtime alias so annotations don't import the protocol
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def _require_livekit() -> Any:
30
+ """Return the ``livekit.agents.llm`` module or raise a friendly error."""
31
+ try:
32
+ from livekit.agents import llm as lk_llm # type: ignore
33
+ except ImportError as e:
34
+ raise RuntimeError(
35
+ "DialogMachineLLM requires the livekit extra: "
36
+ "`pip install superdialog[livekit]`"
37
+ ) from e
38
+ return lk_llm
39
+
40
+
41
class DialogMachineLLM:
    """LiveKit ``Agent(llm=...)`` adapter backed by any superdialog :class:`Agent`.

    Usage::

        from livekit.agents import Agent as LKAgent, AgentSession
        from superdialog import DialogMachine
        from superdialog.adapters.livekit import DialogMachineLLM

        dm = DialogMachine(flow=flow, llm="openai/gpt-5.1")
        lk_agent = LKAgent(llm=DialogMachineLLM(dm))
        await session.start(agent=lk_agent)

    Any superdialog ``Agent`` (``DialogMachine``, ``LLMAgent``,
    ``LangChainAgent``, or a custom Protocol implementation) can stand in
    for ``dm``. The class duck-types LiveKit's LLM protocol; the exact
    method shape depends on the ``livekit-agents`` version pinned in the
    ``[livekit]`` extra. Importing this module without livekit installed
    is allowed; only instantiation fails.
    """

    def __init__(self, agent: Agent) -> None:
        # Probe for livekit up front so a missing extra surfaces at
        # construction time with an actionable message.
        _require_livekit()
        self.agent = agent

    def chat(
        self,
        *,
        chat_ctx: Any,
        fnc_ctx: Any = None,
        conn_options: Any = None,
        **kwargs: Any,
    ) -> "DialogMachineStream":
        """Build the streaming response object LiveKit iterates over.

        ``conn_options`` and any extra keyword arguments are accepted for
        signature compatibility with LiveKit but are not used here.
        """
        response = DialogMachineStream(
            agent=self.agent,
            chat_ctx=chat_ctx,
            fnc_ctx=fnc_ctx,
        )
        return response
81
+
82
+
83
class DialogMachineStream:
    """Async iterator that drives a single :class:`Agent` turn.

    On first iteration it pulls the latest user text out of LiveKit's
    ``ChatContext``, runs ``agent.turn(text, stream=True)`` and converts
    each :class:`StreamChunk` with ``_to_chat_chunk`` (a LiveKit
    ``ChatChunk`` when one can be built, otherwise a plain dict).

    The stream is single-use. Once it is closed via :meth:`aclose`, or by
    leaving an ``async with`` block, further iteration stops immediately
    instead of starting a new agent turn.
    """

    def __init__(
        self,
        agent: Agent,
        chat_ctx: Any,
        fnc_ctx: Any = None,
    ) -> None:
        self._agent = agent
        self._chat_ctx = chat_ctx
        # Kept for LiveKit signature compatibility; not read by this class.
        self._fnc_ctx = fnc_ctx
        self._iter: AsyncIterator[StreamChunk] | None = None
        # Set by aclose(); prevents a closed stream from re-running a turn.
        self._closed = False

    async def _ensure_iter(self) -> AsyncIterator[StreamChunk]:
        """Start the agent turn on first use and return its chunk iterator.

        Raises:
            TypeError: if ``agent.turn(..., stream=True)`` does not return
                an async iterable.
        """
        if self._iter is not None:
            return self._iter
        user_text = _extract_latest_user_text(self._chat_ctx)
        stream = await self._agent.turn(user_text, stream=True)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not hasattr(stream, "__aiter__"):
            raise TypeError("Agent.turn(stream=True) must yield an async iterator")
        # __aiter__() also accepts async iterables that are not iterators
        # themselves; for async generators it returns the same object.
        self._iter = stream.__aiter__()  # type: ignore[union-attr]
        return self._iter

    def __aiter__(self) -> "DialogMachineStream":
        return self

    async def __anext__(self) -> Any:
        if self._closed:
            raise StopAsyncIteration
        iterator = await self._ensure_iter()
        # StopAsyncIteration from the agent stream propagates unchanged.
        chunk: StreamChunk = await iterator.__anext__()
        return _to_chat_chunk(chunk)

    async def aclose(self) -> None:  # pragma: no cover - tested via stop
        """Close the underlying agent stream (if started); safe to call twice."""
        self._closed = True
        iterator, self._iter = self._iter, None
        if iterator is not None and hasattr(iterator, "aclose"):
            await iterator.aclose()

    async def __aenter__(self) -> "DialogMachineStream":
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        await self.aclose()
131
+
132
+
133
+ def _extract_latest_user_text(chat_ctx: Any) -> str:
134
+ """Pull the most recent user message out of a LiveKit ChatContext.
135
+
136
+ Tolerates the small API shifts between livekit-agents minor versions
137
+ by trying a sequence of attribute paths and falling back to ``str``.
138
+ """
139
+ if chat_ctx is None:
140
+ return ""
141
+ messages = (
142
+ getattr(chat_ctx, "messages", None) or getattr(chat_ctx, "items", None) or []
143
+ )
144
+ for msg in reversed(list(messages)):
145
+ role = getattr(msg, "role", None) or (
146
+ msg.get("role") if isinstance(msg, dict) else None
147
+ )
148
+ if role != "user":
149
+ continue
150
+ content = (
151
+ getattr(msg, "content", None)
152
+ if not isinstance(msg, dict)
153
+ else msg.get("content")
154
+ )
155
+ if isinstance(content, list):
156
+ parts = [
157
+ p if isinstance(p, str) else getattr(p, "text", "") for p in content
158
+ ]
159
+ return "".join(parts)
160
+ if isinstance(content, str):
161
+ return content
162
+ return ""
163
+
164
+
165
def _to_chat_chunk(chunk: StreamChunk) -> Any:
    """Render a :class:`StreamChunk` as a LiveKit ``ChatChunk``.

    Older livekit-agents releases expose ``ChatChunk(request_id, delta)``;
    newer ones use keyword-only constructors. Several constructor shapes
    are probed in turn. When ``ChatChunk`` is missing, or every shape fails
    (including building the typed ``ChoiceDelta``), a plain
    ``{"text": ..., "done": ...}`` dict is returned instead.
    """
    lk_llm = _require_livekit()
    fallback = {"text": chunk.text, "done": chunk.done}
    chat_chunk_cls = getattr(lk_llm, "ChatChunk", None)
    if chat_chunk_cls is None:
        return fallback
    candidates: list[dict[str, Any]] = []
    delta_cls = getattr(lk_llm, "ChoiceDelta", None)
    if delta_cls is not None:
        # Building the typed delta can itself fail validation on some
        # versions; that must degrade to the dict-delta shapes below
        # rather than abort the whole stream.
        try:
            delta_obj = delta_cls(role="assistant", content=chunk.text)
        except Exception:  # noqa: BLE001 - probe across LK versions
            delta_obj = None
        if delta_obj is not None:
            candidates.append({"id": "", "delta": delta_obj})
            candidates.append({"request_id": "", "delta": delta_obj})
    candidates.append({"id": "", "delta": {"content": chunk.text}})
    candidates.append({"request_id": "", "delta": {"content": chunk.text}})
    candidates.append({"content": chunk.text})
    for kwargs in candidates:
        try:
            return chat_chunk_cls(**kwargs)
        except Exception:  # noqa: BLE001 - probe across LK versions
            continue
    return fallback


__all__ = ["DialogMachineLLM", "DialogMachineStream"]
@@ -0,0 +1,84 @@
1
+ """PipeCat adapter: a ``FrameProcessor`` that runs any superdialog :class:`Agent`.
2
+
3
+ PipeCat plumbs frames between processors; we accept ``TextFrame`` items
4
+ on the inbound side, drive a single ``Agent.turn(text)``, and emit
5
+ ``TextFrame`` items on the outbound side. Subclassing ``FrameProcessor``
6
+ is deferred so importing this module without the ``pipecat`` extra is
7
+ safe; only :func:`make_processor` (and class instantiation) requires it.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ from typing import TYPE_CHECKING, Any
14
+
15
+ if TYPE_CHECKING: # pragma: no cover
16
+ from superdialog.agent import Agent
17
+ else:
18
+ Agent = Any
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def _require_pipecat() -> tuple[Any, Any]:
24
+ """Return ``(FrameProcessor, TextFrame)`` or raise a friendly error."""
25
+ try:
26
+ from pipecat.frames.frames import TextFrame # type: ignore
27
+ from pipecat.processors.frame_processor import FrameProcessor # type: ignore
28
+ except ImportError as e:
29
+ raise RuntimeError(
30
+ "DialogMachineProcessor requires the pipecat extra: "
31
+ "`pip install superdialog[pipecat]`"
32
+ ) from e
33
+ return FrameProcessor, TextFrame
34
+
35
+
36
def make_processor(agent: Agent) -> Any:
    """Return a PipeCat ``FrameProcessor`` bound to ``agent``.

    Accepts any superdialog :class:`Agent` (``DialogMachine``,
    ``LLMAgent``, ``LangChainAgent``, or a custom implementation).

    Frame handling:

    * ``TextFrame`` with non-empty text -- consumed; ``agent.turn(text)``
      runs and a ``TextFrame(turn.text)`` is pushed in the same direction.
    * ``TextFrame`` with empty text -- dropped.
    * Any other frame -- forwarded unchanged, so control and system frames
      still reach the rest of the pipeline.

    PipeCat's processor API has rotated between releases. Rather than bake
    a specific signature into our import graph, a subclass is synthesised
    at call time from whatever ``FrameProcessor`` is installed.

    Raises:
        RuntimeError: if the ``pipecat`` extra is not installed.
    """
    FrameProcessor, TextFrame = _require_pipecat()

    async def _call_compat(fn: Any, frame: Any, direction: Any) -> None:
        # Some PipeCat releases do not take ``direction``. Only the *call*
        # is guarded: a signature mismatch raises TypeError before the
        # coroutine runs. TypeErrors raised inside ``fn`` propagate instead
        # of triggering a second call with the same frame.
        try:
            pending = fn(frame, direction)
        except TypeError:
            pending = fn(frame)
        await pending

    class DialogMachineProcessor(FrameProcessor):  # type: ignore[misc, valid-type]
        """Concrete processor driving ``agent.turn`` per text frame."""

        def __init__(self) -> None:
            super().__init__()
            self._agent = agent

        async def process_frame(
            self,
            frame: Any,
            direction: Any = None,
        ) -> None:
            # Let the parent update its internal state for control frames.
            parent_process = getattr(super(), "process_frame", None)
            if parent_process is not None:
                await _call_compat(parent_process, frame, direction)
            push = getattr(self, "push_frame", None)
            if not isinstance(frame, TextFrame):
                # PipeCat does not forward frames on a processor's behalf;
                # anything not handled here must be pushed explicitly or it
                # never reaches downstream processors.
                if push is not None:
                    await _call_compat(push, frame, direction)
                return
            user_text = getattr(frame, "text", "") or ""
            if not user_text:
                return
            turn = await self._agent.turn(user_text)
            if push is None:
                return
            await _call_compat(push, TextFrame(turn.text), direction)

    return DialogMachineProcessor()


__all__ = ["make_processor"]
@@ -0,0 +1,200 @@
1
+ """WebSocket runner for any superdialog :class:`Agent` or :class:`SessionWorker`.
2
+
3
+ Two modes:
4
+
5
+ * **Single-tenant** (``WebSocketRunner(agent=...)``) -- one Agent multiplexed
6
+ across all connections. State is shared; use only for demos / single-user
7
+ servers.
8
+ * **Multi-tenant** (``WebSocketRunner(worker=...)``) -- session-keyed. Every
9
+ inbound frame must carry a ``session_id``; the runner acquires it on the
10
+ bound :class:`SessionWorker` and routes the message through the matching
11
+ :class:`SessionHandle`. State is isolated per session_id and persisted via
12
+ whatever ``SessionStore`` the worker is configured with.
13
+
14
+ Protocol (JSON messages over a single WS connection):
15
+
16
+ * Client → server::
17
+
18
+ # single-tenant mode
19
+ {"type": "user_text", "text": "..."}
20
+ {"type": "assist", "text": "..."}
21
+ {"type": "reset"}
22
+
23
+ # multi-tenant mode (session_id required on every frame)
24
+ {"type": "user_text", "session_id": "user-42", "text": "..."}
25
+ {"type": "assist", "session_id": "user-42", "text": "..."}
+ {"type": "reset", "session_id": "user-42"}
26
+
27
+ * Server → client::
28
+
29
+ {"type": "agent_chunk", "text": "...", "done": false}
30
+ {"type": "agent_chunk", "text": "...", "done": true, "metadata": {...}}
31
+ {"type": "assist_ack"}
32
+ {"type": "reset_ack"}
33
+ {"type": "error", "message": "..."}
34
+
35
+ Streaming is preferred; for single-shot replies callers can ignore all
36
+ but the final chunk (it carries the full ``metadata``).
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ import asyncio
42
+ import json
43
+ import logging
44
+ from typing import TYPE_CHECKING, Any
45
+
46
+ if TYPE_CHECKING: # pragma: no cover
47
+ from superdialog.agent import Agent
48
+ from superdialog.session import SessionWorker
49
+ else:
50
+ Agent = Any
51
+ SessionWorker = Any
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ def _require_websockets() -> Any:
57
+ try:
58
+ import websockets # type: ignore
59
+ except ImportError as e:
60
+ raise RuntimeError(
61
+ "WebSocketRunner requires the ws extra: `pip install superdialog[ws]`"
62
+ ) from e
63
+ return websockets
64
+
65
+
66
class WebSocketRunner:
    """Serve a superdialog :class:`Agent` or :class:`SessionWorker` over WS.

    Exactly one of ``agent`` / ``worker`` must be supplied. Mixing the two
    constructor forms raises ``ValueError`` at construction time so callers
    fail fast.

    NOTE(review): ``agent_id`` and ``api_key`` are stored on the instance
    but never read by this class. In particular, no authentication is
    enforced on incoming connections. Put the server behind an
    authenticating proxy, or add a check, before exposing it publicly.
    """

    def __init__(
        self,
        agent: Agent | None = None,
        agent_id: str = "default",
        api_key: str | None = None,
        *,
        worker: SessionWorker | None = None,
    ) -> None:
        if (agent is None) == (worker is None):
            raise ValueError(
                "WebSocketRunner requires exactly one of `agent` or `worker`"
            )
        self.agent = agent
        self.worker = worker
        self.agent_id = agent_id
        self.api_key = api_key

    # ------------------------------------------------------------------
    # Message dispatch
    # ------------------------------------------------------------------

    async def handle_message(self, ws: Any, raw: str | bytes) -> None:
        """Process one inbound JSON frame and stream the response.

        Undecodable input (bad JSON or invalid UTF-8 in a binary frame) is
        answered with ``invalid_json``. Valid JSON that is not an object is
        answered with ``invalid_message``.
        """
        try:
            msg = json.loads(raw)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError
            await self._send_error(ws, "invalid_json")
            return
        if not isinstance(msg, dict):
            await self._send_error(ws, "invalid_message")
            return
        if self.worker is not None:
            await self._handle_with_worker(ws, msg)
        else:
            await self._handle_with_agent(ws, msg)

    async def _handle_with_agent(self, ws: Any, msg: dict[str, Any]) -> None:
        """Single-tenant dispatch -- everything routes through ``self.agent``."""
        kind = msg.get("type")
        if kind == "user_text":
            await self._stream_reply(ws, self.agent, msg.get("text", ""))
        elif kind == "assist":
            await self._call_assist(ws, self.agent, msg.get("text", ""))
        elif kind == "reset":
            await self._call_reset(ws, self.agent)
        else:
            await self._send_error(ws, f"unknown_type:{kind}")

    async def _handle_with_worker(self, ws: Any, msg: dict[str, Any]) -> None:
        """Multi-tenant dispatch -- every frame must carry ``session_id``."""
        session_id = msg.get("session_id")
        if not session_id:
            await self._send_error(ws, "missing_session_id")
            return
        kind = msg.get("type")
        async with self.worker.acquire(session_id) as handle:
            if kind == "user_text":
                await self._stream_reply(ws, handle, msg.get("text", ""))
            elif kind == "assist":
                await self._call_assist(ws, handle, msg.get("text", ""))
            elif kind == "reset":
                # SessionHandle has no reset(); fall back to the underlying
                # agent (reported as unsupported when neither exists).
                await self._call_reset(ws, getattr(handle, "agent", None))
            else:
                await self._send_error(ws, f"unknown_type:{kind}")

    # ------------------------------------------------------------------
    # Per-target helpers (work against Agent OR SessionHandle)
    # ------------------------------------------------------------------

    @staticmethod
    async def _send_error(ws: Any, message: str) -> None:
        """Send a protocol ``error`` frame with the given message code."""
        await ws.send(json.dumps({"type": "error", "message": message}))

    @staticmethod
    async def _resolve(result: Any) -> None:
        """Await ``result`` if it is awaitable (supports async hooks)."""
        import inspect

        if inspect.isawaitable(result):
            await result

    async def _stream_reply(self, ws: Any, target: Any, text: str) -> None:
        stream = await target.turn(text, stream=True)
        async for chunk in stream:  # type: ignore[union-attr]
            frame: dict[str, Any] = {
                "type": "agent_chunk",
                "text": chunk.text,
                "done": chunk.done,
            }
            if chunk.done and chunk.turn is not None:
                frame["metadata"] = chunk.turn.metadata
            await ws.send(json.dumps(frame))

    async def _call_assist(self, ws: Any, target: Any, text: str) -> None:
        assist_fn = getattr(target, "assist", None)
        if assist_fn is None:
            await self._send_error(ws, "assist_unsupported")
            return
        await self._resolve(assist_fn(text))
        await ws.send(json.dumps({"type": "assist_ack"}))

    async def _call_reset(self, ws: Any, target: Any) -> None:
        reset_fn = getattr(target, "reset", None)
        if reset_fn is None:
            await self._send_error(ws, "reset_unsupported")
            return
        await self._resolve(reset_fn())
        await ws.send(json.dumps({"type": "reset_ack"}))

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def _handler(self, ws: Any, *_: Any) -> None:
        try:
            async for raw in ws:
                try:
                    await self.handle_message(ws, raw)
                except Exception:  # noqa: BLE001 - keep the connection alive
                    # One failing turn should not tear down the connection;
                    # report it to the client and keep reading messages.
                    logger.exception("websocket message handling failed")
                    await self._send_error(ws, "internal_error")
        except Exception as exc:  # noqa: BLE001 - protocol-level safety net
            logger.exception("websocket handler crashed: %s", exc)

    def serve(self, host: str = "0.0.0.0", port: int = 8080) -> None:  # nosec B104
        """Block on an asyncio event loop serving the WS endpoint."""
        websockets = _require_websockets()

        async def _main() -> None:
            async with websockets.serve(self._handler, host, port):
                await asyncio.Future()

        asyncio.run(_main())


__all__ = ["WebSocketRunner"]