agentforge-chat 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentforge_chat/__init__.py +40 -0
- agentforge_chat/_idempotency.py +38 -0
- agentforge_chat/_locks.py +115 -0
- agentforge_chat/_segment.py +45 -0
- agentforge_chat/_window.py +86 -0
- agentforge_chat/build.py +112 -0
- agentforge_chat/history.py +126 -0
- agentforge_chat/manifest.yaml +32 -0
- agentforge_chat/py.typed +0 -0
- agentforge_chat/session.py +496 -0
- agentforge_chat/sqlite.py +276 -0
- agentforge_chat/tokenisers.py +91 -0
- agentforge_chat/truncation.py +206 -0
- agentforge_chat-0.2.1.dist-info/METADATA +59 -0
- agentforge_chat-0.2.1.dist-info/RECORD +18 -0
- agentforge_chat-0.2.1.dist-info/WHEEL +4 -0
- agentforge_chat-0.2.1.dist-info/entry_points.txt +9 -0
- agentforge_chat-0.2.1.dist-info/licenses/LICENSE +202 -0
|
@@ -0,0 +1,496 @@
|
|
|
1
|
+
"""`ChatSession` — stateful conversation wrapper around `Agent`
|
|
2
|
+
(feat-020).
|
|
3
|
+
|
|
4
|
+
Lifecycle per turn (`send` / `stream`):
|
|
5
|
+
|
|
6
|
+
1. Acquire per-session lock.
|
|
7
|
+
2. Check idempotency cache.
|
|
8
|
+
3. Run input guardrails, then build the user `ChatTurn`.
|
|
9
|
+
4. Append user turn.
|
|
10
|
+
5. Load + truncate prior history.
|
|
11
|
+
6. Build the agent task as a serialised transcript.
|
|
12
|
+
7. `agent.run(task)`.
|
|
13
|
+
8. Run output guardrails.
|
|
14
|
+
9. Append the assistant turn.
|
|
15
|
+
10. Aggregate per-turn / per-session cost.
|
|
16
|
+
11. Release lock; return `ChatResponse`.
|
|
17
|
+
|
|
18
|
+
Cancellation in v0.2 is pre-LLM only (checked once, after the user turn
is appended and before history is loaded for agent.run). Mid-LLM
cancellation needs the same strategy-level
|
|
20
|
+
streaming work documented in feat-020 §10 deferrals.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import asyncio
|
|
26
|
+
import time
|
|
27
|
+
from collections.abc import AsyncIterator, Callable
|
|
28
|
+
from typing import Any, Literal
|
|
29
|
+
from uuid import uuid4
|
|
30
|
+
|
|
31
|
+
from agentforge.agent import Agent
|
|
32
|
+
from agentforge_core.contracts.chat import ChatHistoryStore, HistoryTruncationStrategy
|
|
33
|
+
from agentforge_core.contracts.strategy import ReasoningStrategy
|
|
34
|
+
from agentforge_core.production.exceptions import (
|
|
35
|
+
BudgetExceeded,
|
|
36
|
+
GuardrailViolation,
|
|
37
|
+
)
|
|
38
|
+
from agentforge_core.values.chat import ChatChunk, ChatResponse, ChatTurn, StreamingEvent
|
|
39
|
+
|
|
40
|
+
from agentforge_chat._idempotency import IdempotencyCache
|
|
41
|
+
from agentforge_chat._locks import (
|
|
42
|
+
SessionLockFactory,
|
|
43
|
+
default_session_lock_factory,
|
|
44
|
+
)
|
|
45
|
+
from agentforge_chat._segment import segment_for_stream
|
|
46
|
+
from agentforge_chat._window import _SentenceWindowBuffer
|
|
47
|
+
from agentforge_chat.history import InMemoryChatHistory
|
|
48
|
+
from agentforge_chat.truncation import SlidingWindow
|
|
49
|
+
|
|
50
|
+
# Synchronous callback that ChatSession invokes with each turn it persists
# (user and assistant turns alike); its return value is ignored.
OnTurnHook = Callable[[ChatTurn], None]

# Accepted values for ChatSession(safety_mode=...). The streaming code path
# treats "sentence-window" and "stream-then-redact" identically.
SafetyMode = Literal["buffer-then-stream", "sentence-window", "stream-then-redact"]
"""Output-guardrail policy on streamed assistant turns. See
:class:`agentforge_core.config.schema.ChatSessionConfig` for
per-value semantics. ``"stream-then-redact"`` is currently an
alias for ``"sentence-window"`` (feat-020 v0.3 polish)."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ChatSession:
    """Wrap a one-shot `Agent` into a multi-turn chat session.

    Every turn runs under a per-session lock, is persisted to a
    `ChatHistoryStore`, and is charged against optional per-turn and
    per-session USD budgets. Replies are either awaited whole (`send`)
    or consumed as `ChatChunk`s (`stream`). Both accept an idempotency
    key: a repeated key within the idempotency window replays the cached
    `ChatResponse` instead of running the agent again.

    Once the per-session budget has been exceeded, new turns are refused
    with `BudgetExceeded` before the agent is invoked.
    """

    def __init__(
        self,
        agent: Agent,
        *,
        session_id: str | None = None,
        history_store: ChatHistoryStore | None = None,
        system_prompt: str | None = None,
        truncation: HistoryTruncationStrategy | None = None,
        owner: str | None = None,
        per_turn_budget_usd: float | None = None,
        per_session_budget_usd: float | None = None,
        idempotency_window_s: float = 60.0,
        on_turn: OnTurnHook | None = None,
        session_lock_factory: SessionLockFactory | None = None,
        safety_mode: SafetyMode = "buffer-then-stream",
    ) -> None:
        """Create a chat session around ``agent``.

        Args:
            agent: The one-shot agent that answers each turn.
            session_id: Conversation id; a random hex id when omitted.
            history_store: Turn storage; defaults to `InMemoryChatHistory`.
            system_prompt: Optional text placed first in every composed task.
            truncation: Chooses which prior turns are replayed to the agent;
                defaults to ``SlidingWindow(50)``.
            owner: Stored in session metadata; the guardrail context uses
                ``"anonymous"`` when it is omitted.
            per_turn_budget_usd: `BudgetExceeded` is raised when a single
                turn costs more than this.
            per_session_budget_usd: `BudgetExceeded` is raised when the
                session total exceeds this.
            idempotency_window_s: TTL for responses cached by key.
            on_turn: Synchronous hook called with every persisted turn.
            session_lock_factory: Builds the per-session lock from the
                session id; defaults to `default_session_lock_factory`.
            safety_mode: Output-guardrail policy for streamed turns.
        """
        self._agent = agent
        self._session_id = session_id if session_id is not None else uuid4().hex
        self._history: ChatHistoryStore = (
            history_store if history_store is not None else InMemoryChatHistory()
        )
        self._system_prompt = system_prompt
        self._truncation: HistoryTruncationStrategy = (
            truncation if truncation is not None else SlidingWindow(50)
        )
        self._owner = owner
        self._per_turn_budget = per_turn_budget_usd
        self._per_session_budget = per_session_budget_usd
        self._on_turn = on_turn
        factory = session_lock_factory or default_session_lock_factory
        self._lock = factory(self._session_id)
        self._idempotency: IdempotencyCache[ChatResponse] = IdempotencyCache(
            ttl_s=idempotency_window_s
        )
        self._total_cost = 0.0
        self._turn_count = 0
        self._closed = False
        self._safety_mode: SafetyMode = safety_mode

    # ------------------------------------------------------------------
    # Public properties
    # ------------------------------------------------------------------

    @property
    def session_id(self) -> str:
        """Identifier under which this session's turns are stored."""
        return self._session_id

    @property
    def total_cost_usd(self) -> float:
        """Accumulated agent cost (USD) since creation or the last `reset`."""
        return self._total_cost

    @property
    def turn_count(self) -> int:
        """Number of completed assistant turns since creation or `reset`."""
        return self._turn_count

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def send(
        self,
        message: str,
        *,
        idempotency_key: str | None = None,
        cancellation: asyncio.Event | None = None,
    ) -> ChatResponse:
        """Send one user message and await a buffered response.

        Args:
            message: The user's text, passed through input guardrails.
            idempotency_key: When given, a repeat of the same key inside the
                idempotency window returns the cached response unchanged.
            cancellation: If set before the agent starts, the turn is
                abandoned with ``asyncio.CancelledError``.

        Raises:
            BudgetExceeded: Per-turn or per-session budget exceeded.
            asyncio.CancelledError: ``cancellation`` was set pre-LLM.
        """
        async with self._lock:
            cached = self._check_cache(idempotency_key)
            if cached is not None:
                return cached
            response, _ = await self._run_turn(message, cancellation=cancellation)
            self._stash_cache(idempotency_key, response)
            return response

    async def stream(
        self,
        message: str,
        *,
        idempotency_key: str | None = None,
        cancellation: asyncio.Event | None = None,
    ) -> AsyncIterator[ChatChunk]:
        """Send one user message and stream the response back as chunks.

        Awaiting this coroutine returns an async iterator::

            async for chunk in await session.stream("hi"):
                ...

        If the agent's strategy overrides `ReasoningStrategy.stream`, events
        are forwarded as the agent produces them and assistant text is gated
        according to ``safety_mode``. Otherwise the agent runs to completion
        and the validated reply is emitted as sentence-segmented ``text``
        chunks. The stream ends with a ``done`` chunk (run id, cost, token
        counts), or with a single ``error`` chunk when the turn fails with
        `BudgetExceeded`, `GuardrailViolation`, or cancellation requested
        through ``cancellation``.
        """
        return self._stream_impl(message, idempotency_key, cancellation)

    async def history(
        self,
        *,
        limit: int | None = None,
        roles: list[str] | None = None,
    ) -> list[ChatTurn]:
        """Load this session's stored turns, optionally limited / role-filtered."""
        return await self._history.load(self._session_id, limit=limit, roles=roles)

    async def reset(self) -> None:
        """Delete the stored history and zero the cost / turn counters.

        Responses already in the idempotency cache are not cleared.
        """
        await self._history.delete_session(self._session_id)
        self._total_cost = 0.0
        self._turn_count = 0

    async def close(self) -> None:
        """Close the history store; repeated calls are no-ops.

        ``_closed`` only guards repeated ``close()`` calls; ``send`` and
        ``stream`` do not consult it.
        """
        if self._closed:
            return
        await self._history.close()
        self._closed = True

    # ------------------------------------------------------------------
    # Implementation helpers
    # ------------------------------------------------------------------

    def _check_cache(self, key: str | None) -> ChatResponse | None:
        """Return the cached response for ``key``, or None (always None without a key)."""
        if key is None:
            return None
        return self._idempotency.get(self._session_id, key)

    def _stash_cache(self, key: str | None, response: ChatResponse) -> None:
        """Cache ``response`` under ``key``; no-op when no key was supplied."""
        if key is None:
            return
        self._idempotency.put(self._session_id, key, response)

    async def _prepare_turn(
        self,
        message: str,
        *,
        cancellation: asyncio.Event | None,
        stage: str,
    ) -> tuple[dict[str, Any], str]:
        """Run the shared pre-LLM steps of a turn; return (guard context, task).

        ``stage`` only names the agent entry point in the cancellation message.

        Raises:
            BudgetExceeded: The session is already over its budget; checked
                first so no user turn is recorded and no LLM cost incurred.
            asyncio.CancelledError: ``cancellation`` was set after the user
                turn was appended.
        """
        self._check_session_budget()
        ctx = self._guard_context()
        validated_msg = await self._agent._guardrails.check_input(message, ctx)
        user_turn = await self._build_user_turn(validated_msg)
        if cancellation is not None and cancellation.is_set():
            raise asyncio.CancelledError(f"chat turn cancelled before {stage}")
        task = await self._compose_task(user_turn)
        return ctx, task

    async def _run_turn(
        self,
        message: str,
        *,
        cancellation: asyncio.Event | None,
    ) -> tuple[ChatResponse, ChatTurn]:
        """Run one buffered turn: guard, persist, ``agent.run``, guard, persist, bill.

        The assistant turn is persisted (and its cost counted) before the
        budgets are checked, so an over-budget turn is still recorded.
        """
        ctx, task = await self._prepare_turn(
            message, cancellation=cancellation, stage="agent.run"
        )
        start = time.monotonic()
        result = await self._agent.run(task)
        duration_ms = int((time.monotonic() - start) * 1000)
        validated_out = await self._agent._guardrails.check_output(
            self._extract_text(result), ctx
        )
        assistant_turn = await self._persist_assistant(validated_out, result, duration_ms)
        self._enforce_budgets(result.cost_usd)
        response = ChatResponse(
            content=validated_out,
            turn_id=assistant_turn.id,
            run_id=result.run_id,
            tool_calls=(),
            tokens_in=result.tokens_in,
            tokens_out=result.tokens_out,
            cost_usd=result.cost_usd,
            duration_ms=duration_ms,
            finish_reason=str(result.finish_reason),
        )
        return response, assistant_turn

    def _guard_context(self) -> dict[str, Any]:
        """Context dict passed to the agent's input and output guardrails."""
        return {
            "session_id": self._session_id,
            "owner": self._owner or "anonymous",
            "project": "chat",
        }

    async def _build_user_turn(self, content: str) -> ChatTurn:
        """Create the user turn, append it to history, and notify ``on_turn``."""
        turn = ChatTurn(
            id=uuid4().hex,
            session_id=self._session_id,
            role="user",
            content=content,
        )
        await self._history.append(turn)
        if self._on_turn is not None:
            self._on_turn(turn)
        return turn

    async def _compose_task(self, user_turn: ChatTurn) -> str:
        """Serialise system prompt, truncated prior turns and the new message.

        Blocks are ``role: content`` lines joined by blank lines, with the
        current user message always last.
        """
        prior = await self._history.load(self._session_id, limit=None)
        # Drop the just-appended user turn — it's added as the final
        # line below.
        prior_without_current = [t for t in prior if t.id != user_turn.id]
        kept = await self._truncation.select(prior_without_current, user_turn.content, {})
        lines: list[str] = []
        prompt = self._system_prompt or ""
        if prompt:
            lines.append(prompt)
        lines.extend(f"{t.role}: {t.content}" for t in kept)
        lines.append(f"user: {user_turn.content}")
        return "\n\n".join(lines)

    def _extract_text(self, result: Any) -> str:
        """Return ``result.output`` as text (``str()`` of non-string outputs)."""
        output = result.output
        if isinstance(output, str):
            return output
        return str(output)

    async def _record_assistant_turn(
        self,
        *,
        turn_id: str,
        text: str,
        run_id: Any,
        tokens_in: Any,
        tokens_out: Any,
        cost_usd: Any,
    ) -> ChatTurn:
        """Persist an assistant turn, notify ``on_turn``, and update session totals."""
        turn = ChatTurn(
            id=turn_id,
            session_id=self._session_id,
            role="assistant",
            content=text,
            run_id=run_id,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost_usd=cost_usd,
        )
        await self._history.append(turn)
        if self._on_turn is not None:
            self._on_turn(turn)
        self._total_cost += cost_usd
        self._turn_count += 1
        await self._history.update_session_metadata(
            self._session_id,
            {"owner": self._owner, "total_cost_usd": self._total_cost},
        )
        return turn

    async def _persist_assistant(
        self,
        text: str,
        result: Any,
        duration_ms: int,
    ) -> ChatTurn:
        """Persist the assistant reply for an ``agent.run`` result under a fresh id."""
        del duration_ms  # not recorded on the turn
        return await self._record_assistant_turn(
            turn_id=uuid4().hex,
            text=text,
            run_id=result.run_id,
            tokens_in=result.tokens_in,
            tokens_out=result.tokens_out,
            cost_usd=result.cost_usd,
        )

    def _check_session_budget(self) -> None:
        """Raise `BudgetExceeded` if the session total is above its budget."""
        if self._per_session_budget is not None and self._total_cost > self._per_session_budget:
            raise BudgetExceeded(
                f"chat session total ${self._total_cost:.4f} exceeds per-session "
                f"budget ${self._per_session_budget:.4f}"
            )

    def _enforce_budgets(self, turn_cost: float) -> None:
        """Raise `BudgetExceeded` if this turn or the session is over budget."""
        if self._per_turn_budget is not None and turn_cost > self._per_turn_budget:
            raise BudgetExceeded(
                f"chat turn cost ${turn_cost:.4f} exceeds per-turn budget "
                f"${self._per_turn_budget:.4f}"
            )
        self._check_session_budget()

    def _strategy_overrides_stream(self) -> bool:
        """True when the agent's strategy defines its own `stream()`.

        Distinguishes "real per-token streaming" from the default
        ABC behaviour (which just wraps `run()` + emits one `done`).
        Real per-token strategies override `stream()` to yield text /
        tool-call events as the LLM emits them. v0.2 falls back to
        buffer-then-stream when the override isn't there so v0.1
        callers get the same wire shape they had before.
        """
        return type(self._agent._strategy).stream is not ReasoningStrategy.stream

    @staticmethod
    def _is_external_cancellation(
        exc: BaseException,
        cancellation: asyncio.Event | None,
    ) -> bool:
        """True when ``exc`` is a task cancellation not requested via ``cancellation``.

        Cooperative cancellation (the caller set ``cancellation``) is
        reported to the client as an ``error`` chunk. Any other
        ``CancelledError`` (the consuming task being cancelled, a timeout)
        must propagate; converting it into a chunk would swallow it.
        """
        if not isinstance(exc, asyncio.CancelledError):
            return False
        return cancellation is None or not cancellation.is_set()

    @staticmethod
    def _error_chunk(exc: BaseException) -> ChatChunk:
        """Build the terminal ``error`` chunk describing ``exc``."""
        return ChatChunk(
            kind="error",
            turn_id=uuid4().hex,
            content={"reason": type(exc).__name__, "message": str(exc)},
        )

    async def _stream_impl(
        self,
        message: str,
        idempotency_key: str | None,
        cancellation: asyncio.Event | None,
    ) -> AsyncIterator[ChatChunk]:
        """Async-generator body behind `stream`; holds the session lock throughout.

        The lock is held across ``yield``, so a consumer that stops
        iterating early should ``aclose()`` the iterator to release it
        promptly.
        """
        async with self._lock:
            cached = self._check_cache(idempotency_key)
            if cached is not None:
                async for chunk in self._chunks_for(cached):
                    yield chunk
                return
            try:
                if self._strategy_overrides_stream():
                    async for chunk in self._stream_per_token(
                        message,
                        cancellation=cancellation,
                        idempotency_key=idempotency_key,
                    ):
                        yield chunk
                    return  # noqa: TRY300 — return in try is the explicit happy-path exit
                response, _ = await self._run_turn(message, cancellation=cancellation)
            except (BudgetExceeded, GuardrailViolation, asyncio.CancelledError) as exc:
                if self._is_external_cancellation(exc, cancellation):
                    raise
                yield self._error_chunk(exc)
                return
            self._stash_cache(idempotency_key, response)
            async for chunk in self._chunks_for(response):
                yield chunk

    @staticmethod
    def _summary_usage(summary: dict[str, Any]) -> tuple[str, int, int, float]:
        """Coerce a run summary's ``(run_id, tokens_in, tokens_out, cost_usd)``.

        Missing or falsy values become ``""`` / ``0`` / ``0.0``.
        """
        return (
            str(summary.get("run_id", "")),
            int(summary.get("tokens_in", 0) or 0),
            int(summary.get("tokens_out", 0) or 0),
            float(summary.get("cost_usd", 0.0) or 0.0),
        )

    @staticmethod
    def _final_text(summary: dict[str, Any], fallback: str) -> str:
        """Return the summary's ``output`` when it is a string, else ``fallback``."""
        output = summary.get("output")
        return str(output) if isinstance(output, str) else fallback

    async def _stream_per_token(  # noqa: PLR0912
        self,
        message: str,
        *,
        cancellation: asyncio.Event | None,
        idempotency_key: str | None = None,
    ) -> AsyncIterator[ChatChunk]:
        """Drive the agent via `agent.stream(task)` and forward its events.

        Persists the user and final assistant turns, enforces budgets, and
        caches the resulting `ChatResponse` under ``idempotency_key`` the
        same way the buffered path does.

        Assistant ``text`` handling depends on ``safety_mode``:

        * ``"buffer-then-stream"`` (default): text is held until the run
          completes; the whole reply passes ``check_output`` once and is
          then emitted as sentence-segmented chunks, so no unvalidated text
          reaches the client.
        * ``"sentence-window"`` / ``"stream-then-redact"``: text is buffered
          until a sentence boundary, and each sentence passes
          ``check_output`` before being emitted.

        Non-text events (``tool_call``, ``step``, ``error``) pass through
        unbuffered in every mode.
        """
        ctx, task = await self._prepare_turn(
            message, cancellation=cancellation, stage="agent.stream"
        )
        assistant_turn_id = uuid4().hex
        window = (
            _SentenceWindowBuffer()
            if self._safety_mode in ("sentence-window", "stream-then-redact")
            else None
        )
        # sentence-window: validated text already sent to the client.
        # buffer-then-stream: raw text held back until the turn completes.
        cumulative = ""
        run_summary: dict[str, Any] | None = None
        start = time.monotonic()
        async for event in self._agent.stream(task):
            if event.kind == "done":
                if isinstance(event.content, dict):
                    run_summary = event.content
                # Don't break — let the generator yield its terminal
                # event and complete naturally so `Agent.stream`'s
                # `finally: reset_run(token)` fires deterministically.
                continue
            if event.kind == "text" and isinstance(event.content, str):
                if window is None:
                    cumulative += event.content
                    continue
                for sentence in window.push(event.content):
                    validated = await self._agent._guardrails.check_output(sentence, ctx)
                    cumulative += validated
                    yield ChatChunk(
                        kind="text",
                        content=validated,
                        cumulative_text=cumulative,
                        turn_id=assistant_turn_id,
                        metadata=dict(event.metadata),
                    )
                continue
            yield self._chunk_from_event(event, assistant_turn_id)
        if window is not None:
            residual = window.flush()
            if residual:
                validated = await self._agent._guardrails.check_output(residual, ctx)
                cumulative += validated
                yield ChatChunk(
                    kind="text",
                    content=validated,
                    cumulative_text=cumulative,
                    turn_id=assistant_turn_id,
                    metadata={},
                )
        duration_ms = int((time.monotonic() - start) * 1000)
        if run_summary is None:
            run_summary = {
                "output": cumulative,
                "run_id": uuid4().hex,
                "cost_usd": 0.0,
                "tokens_in": 0,
                "tokens_out": 0,
                "finish_reason": "completed",
                "duration_ms": duration_ms,
            }
        if window is not None:
            # Each sentence already passed through `check_output`;
            # `cumulative` is the validated text. Skip the terminal
            # double-validation.
            validated_out = cumulative
        else:
            validated_out = await self._agent._guardrails.check_output(
                self._final_text(run_summary, cumulative), ctx
            )
        run_id, tokens_in, tokens_out, cost_usd = self._summary_usage(run_summary)
        await self._record_assistant_turn(
            turn_id=assistant_turn_id,
            text=validated_out,
            run_id=run_id,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost_usd=cost_usd,
        )
        self._enforce_budgets(cost_usd)
        response = ChatResponse(
            content=validated_out,
            turn_id=assistant_turn_id,
            run_id=run_id,
            tool_calls=(),
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost_usd=cost_usd,
            duration_ms=duration_ms,
            finish_reason=str(run_summary.get("finish_reason") or "completed"),
        )
        self._stash_cache(idempotency_key, response)
        if window is None:
            # Validated text has not been sent yet: emit it, then `done`.
            async for chunk in self._chunks_for(response):
                yield chunk
        else:
            yield self._done_chunk(response)

    def _chunk_from_event(self, event: StreamingEvent, turn_id: str) -> ChatChunk:
        """Re-label a strategy `StreamingEvent` as a `ChatChunk` for ``turn_id``."""
        return ChatChunk(
            kind=event.kind,
            content=event.content,
            cumulative_text=event.cumulative_text,
            turn_id=turn_id,
            metadata=dict(event.metadata),
        )

    @staticmethod
    def _done_chunk(response: ChatResponse) -> ChatChunk:
        """Build the terminal ``done`` chunk carrying the response's usage."""
        return ChatChunk(
            kind="done",
            turn_id=response.turn_id,
            content={
                "run_id": response.run_id,
                "cost_usd": response.cost_usd,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
            },
        )

    async def _chunks_for(self, response: ChatResponse) -> AsyncIterator[ChatChunk]:
        """Replay a complete response as segmented ``text`` chunks plus ``done``."""
        cumulative = ""
        for piece in segment_for_stream(response.content):
            cumulative += piece
            yield ChatChunk(
                kind="text",
                content=piece,
                cumulative_text=cumulative,
                turn_id=response.turn_id,
            )
        yield self._done_chunk(response)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
# Only the session class is part of this module's public API.
__all__ = ["ChatSession"]
|