agentforge-chat 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,496 @@
1
+ """`ChatSession` — stateful conversation wrapper around `Agent`
2
+ (feat-020).
3
+
4
+ Lifecycle per turn (`send` / `stream`):
5
+
6
+ 1. Acquire per-session lock.
7
+ 2. Check idempotency cache.
8
+ 3. Build user `ChatTurn` and run input guardrails.
9
+ 4. Append user turn.
10
+ 5. Load + truncate prior history.
11
+ 6. Build the agent task as a serialised transcript.
12
+ 7. `agent.run(task)`.
13
+ 8. Run output guardrails.
14
+ 9. Append assistant + tool turns.
15
+ 10. Aggregate per-turn / per-session cost.
16
+ 11. Release lock; return `ChatResponse`.
17
+
18
+ Cancellation in v0.2 is pre-LLM only (between history-load and
19
+ agent.run). Mid-LLM cancellation needs the same strategy-level
20
+ streaming work documented in feat-020 §10 deferrals.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import asyncio
26
+ import time
27
+ from collections.abc import AsyncIterator, Callable
28
+ from typing import Any, Literal
29
+ from uuid import uuid4
30
+
31
+ from agentforge.agent import Agent
32
+ from agentforge_core.contracts.chat import ChatHistoryStore, HistoryTruncationStrategy
33
+ from agentforge_core.contracts.strategy import ReasoningStrategy
34
+ from agentforge_core.production.exceptions import (
35
+ BudgetExceeded,
36
+ GuardrailViolation,
37
+ )
38
+ from agentforge_core.values.chat import ChatChunk, ChatResponse, ChatTurn, StreamingEvent
39
+
40
+ from agentforge_chat._idempotency import IdempotencyCache
41
+ from agentforge_chat._locks import (
42
+ SessionLockFactory,
43
+ default_session_lock_factory,
44
+ )
45
+ from agentforge_chat._segment import segment_for_stream
46
+ from agentforge_chat._window import _SentenceWindowBuffer
47
+ from agentforge_chat.history import InMemoryChatHistory
48
+ from agentforge_chat.truncation import SlidingWindow
49
+
50
#: Synchronous callback the session invokes with each `ChatTurn` it
#: persists (user and assistant turns alike).
OnTurnHook = Callable[[ChatTurn], None]

#: Output-guardrail policy on streamed assistant turns. Per-value
#: semantics are documented on
#: :class:`agentforge_core.config.schema.ChatSessionConfig`.
#: ``"stream-then-redact"`` is currently an alias for
#: ``"sentence-window"`` (feat-020 v0.3 polish).
SafetyMode = Literal[
    "buffer-then-stream",
    "sentence-window",
    "stream-then-redact",
]
57
+
58
+
59
class ChatSession:
    """Wrap a one-shot `Agent` into a multi-turn chat session.

    Turns are serialised by a per-session lock, persisted to a
    `ChatHistoryStore`, and replayed to the agent as a plain-text
    transcript selected by a `HistoryTruncationStrategy`. Responses can
    be cached under caller-supplied idempotency keys so a retried turn
    does not run the agent a second time.
    """

    def __init__(
        self,
        agent: Agent,
        *,
        session_id: str | None = None,
        history_store: ChatHistoryStore | None = None,
        system_prompt: str | None = None,
        truncation: HistoryTruncationStrategy | None = None,
        owner: str | None = None,
        per_turn_budget_usd: float | None = None,
        per_session_budget_usd: float | None = None,
        idempotency_window_s: float = 60.0,
        on_turn: OnTurnHook | None = None,
        session_lock_factory: SessionLockFactory | None = None,
        safety_mode: SafetyMode = "buffer-then-stream",
    ) -> None:
        """Create a session bound to ``agent``.

        Args:
            agent: The `Agent` answering each turn; its guardrails and
                strategy are used directly.
            session_id: Session identifier; a random hex UUID when omitted.
            history_store: Turn persistence; defaults to
                `InMemoryChatHistory()`.
            system_prompt: Optional text placed first in every transcript.
            truncation: Selects which prior turns are replayed; defaults
                to `SlidingWindow(50)`.
            owner: Stored in session metadata and the guardrail context
                (``"anonymous"`` in the context when omitted).
            per_turn_budget_usd: Maximum cost of one turn. Checked after
                the turn completes, so the offending turn is still persisted.
            per_session_budget_usd: Maximum cumulative session cost.
            idempotency_window_s: TTL passed to the idempotency cache.
            on_turn: Synchronous hook called with every persisted turn.
            session_lock_factory: Builds the per-session lock; defaults to
                `default_session_lock_factory`.
            safety_mode: Output-guardrail policy for per-token streaming.
        """
        self._agent = agent
        self._session_id = session_id if session_id is not None else uuid4().hex
        self._history: ChatHistoryStore = (
            history_store if history_store is not None else InMemoryChatHistory()
        )
        self._system_prompt = system_prompt
        self._truncation: HistoryTruncationStrategy = (
            truncation if truncation is not None else SlidingWindow(50)
        )
        self._owner = owner
        self._per_turn_budget = per_turn_budget_usd
        self._per_session_budget = per_session_budget_usd
        self._on_turn = on_turn
        factory = session_lock_factory or default_session_lock_factory
        self._lock = factory(self._session_id)
        # Kept so `reset()` can rebuild an empty cache with the same TTL.
        self._idempotency_window_s = idempotency_window_s
        self._idempotency: IdempotencyCache[ChatResponse] = self._new_idempotency_cache()
        self._total_cost = 0.0
        self._turn_count = 0
        self._closed = False
        self._safety_mode: SafetyMode = safety_mode

    # ------------------------------------------------------------------
    # Public properties
    # ------------------------------------------------------------------

    @property
    def session_id(self) -> str:
        """Identifier this session persists turns under."""
        return self._session_id

    @property
    def total_cost_usd(self) -> float:
        """Cumulative cost of all assistant turns since creation/reset."""
        return self._total_cost

    @property
    def turn_count(self) -> int:
        """Number of assistant turns recorded since creation/reset."""
        return self._turn_count

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def send(
        self,
        message: str,
        *,
        idempotency_key: str | None = None,
        cancellation: asyncio.Event | None = None,
    ) -> ChatResponse:
        """Send one user message and await a buffered response.

        A repeated ``idempotency_key`` that is still cached returns the
        earlier response without running the agent again.

        Raises:
            BudgetExceeded: A per-turn or per-session budget is exceeded.
            GuardrailViolation: Presumably raised by the agent's input or
                output guardrails.
            asyncio.CancelledError: ``cancellation`` was set before the
                agent ran.
        """
        async with self._lock:
            cached = self._check_cache(idempotency_key)
            if cached is not None:
                return cached
            response, _ = await self._run_turn(message, cancellation=cancellation)
            self._stash_cache(idempotency_key, response)
            return response

    async def stream(
        self,
        message: str,
        *,
        idempotency_key: str | None = None,
        cancellation: asyncio.Event | None = None,
    ) -> AsyncIterator[ChatChunk]:
        """Send one user message and stream the response back as chunks.

        Await this coroutine to obtain the async iterator, then
        ``async for`` over it.

        When the agent's strategy overrides `ReasoningStrategy.stream`,
        events are forwarded as the strategy emits them (text handling
        depends on ``safety_mode``). Otherwise the agent runs to
        completion and the reply is replayed as sentence-segmented
        ``text`` chunks followed by a ``done`` chunk.

        Budget, guardrail and cooperative-cancellation failures are
        reported as an ``error`` chunk rather than raised. On the
        per-token path, earlier text chunks may precede that error chunk.
        """
        return self._stream_impl(message, idempotency_key, cancellation)

    async def history(
        self,
        *,
        limit: int | None = None,
        roles: list[str] | None = None,
    ) -> list[ChatTurn]:
        """Load this session's turns from the history store.

        ``limit`` and ``roles`` are passed through to the store's `load`.
        """
        return await self._history.load(self._session_id, limit=limit, roles=roles)

    async def reset(self) -> None:
        """Delete this session's history and zero its cost/turn counters.

        Cached idempotent responses are discarded as well. They reference
        turns that no longer exist, so replaying them after a reset would
        return stale answers.
        """
        await self._history.delete_session(self._session_id)
        self._idempotency = self._new_idempotency_cache()
        self._total_cost = 0.0
        self._turn_count = 0

    async def close(self) -> None:
        """Close the history store; later calls are no-ops."""
        if self._closed:
            return
        await self._history.close()
        self._closed = True

    # ------------------------------------------------------------------
    # Implementation helpers
    # ------------------------------------------------------------------

    def _new_idempotency_cache(self) -> IdempotencyCache[ChatResponse]:
        """Build an empty idempotency cache using the configured TTL."""
        return IdempotencyCache(ttl_s=self._idempotency_window_s)

    def _check_cache(self, key: str | None) -> ChatResponse | None:
        """Return the cached response for ``key``, or None (also for no key)."""
        if key is None:
            return None
        return self._idempotency.get(self._session_id, key)

    def _stash_cache(self, key: str | None, response: ChatResponse) -> None:
        """Cache ``response`` under ``key``; no-op when ``key`` is None."""
        if key is None:
            return
        self._idempotency.put(self._session_id, key, response)

    async def _run_turn(
        self,
        message: str,
        *,
        cancellation: asyncio.Event | None,
    ) -> tuple[ChatResponse, ChatTurn]:
        """Run one buffered turn and return ``(response, assistant_turn)``."""
        # Refuse up front once the session budget is spent, so no further
        # LLM calls are paid for.
        self._check_session_budget()
        ctx = self._guard_context()
        validated_msg = await self._agent._guardrails.check_input(message, ctx)
        user_turn = await self._build_user_turn(validated_msg)
        if cancellation is not None and cancellation.is_set():
            raise asyncio.CancelledError("chat turn cancelled before agent.run")
        task = await self._compose_task(user_turn)
        start = time.monotonic()
        result = await self._agent.run(task)
        duration_ms = int((time.monotonic() - start) * 1000)
        validated_out = await self._agent._guardrails.check_output(self._extract_text(result), ctx)
        assistant_turn = await self._persist_assistant(validated_out, result, duration_ms)
        self._enforce_budgets(result.cost_usd)
        response = ChatResponse(
            content=validated_out,
            turn_id=assistant_turn.id,
            run_id=result.run_id,
            tool_calls=(),
            tokens_in=result.tokens_in,
            tokens_out=result.tokens_out,
            cost_usd=result.cost_usd,
            duration_ms=duration_ms,
            finish_reason=str(result.finish_reason),
        )
        return response, assistant_turn

    def _guard_context(self) -> dict[str, Any]:
        """Context dict handed to the agent's input/output guardrails."""
        return {
            "session_id": self._session_id,
            "owner": self._owner or "anonymous",
            "project": "chat",
        }

    async def _build_user_turn(self, content: str) -> ChatTurn:
        """Create, persist and announce (via ``on_turn``) a user turn."""
        turn = ChatTurn(
            id=uuid4().hex,
            session_id=self._session_id,
            role="user",
            content=content,
        )
        await self._history.append(turn)
        if self._on_turn is not None:
            self._on_turn(turn)
        return turn

    async def _compose_task(self, user_turn: ChatTurn) -> str:
        """Serialise system prompt + truncated history + new message."""
        prior = await self._history.load(self._session_id, limit=None)
        # Drop the just-appended user turn — it's added as the final
        # line below.
        prior_without_current = [t for t in prior if t.id != user_turn.id]
        kept = await self._truncation.select(prior_without_current, user_turn.content, {})
        lines: list[str] = []
        prompt = self._system_prompt or ""
        if prompt:
            lines.append(prompt)
        lines.extend(f"{t.role}: {t.content}" for t in kept)
        lines.append(f"user: {user_turn.content}")
        return "\n\n".join(lines)

    def _extract_text(self, result: Any) -> str:
        """Return ``result.output`` as text (``str()`` for non-strings)."""
        output = result.output
        if isinstance(output, str):
            return output
        return str(output)

    async def _persist_assistant(
        self,
        text: str,
        result: Any,
        duration_ms: int,
    ) -> ChatTurn:
        """Build the assistant turn for a buffered run and record it."""
        del duration_ms  # accepted for call-site symmetry; not stored on the turn
        turn = ChatTurn(
            id=uuid4().hex,
            session_id=self._session_id,
            role="assistant",
            content=text,
            run_id=result.run_id,
            tokens_in=result.tokens_in,
            tokens_out=result.tokens_out,
            cost_usd=result.cost_usd,
        )
        await self._record_assistant_turn(turn, result.cost_usd)
        return turn

    async def _record_assistant_turn(self, turn: ChatTurn, cost_usd: float) -> None:
        """Persist an assistant turn, fire ``on_turn``, and update totals.

        Shared by the buffered and per-token paths so both account for
        cost and session metadata identically.
        """
        await self._history.append(turn)
        if self._on_turn is not None:
            self._on_turn(turn)
        self._total_cost += cost_usd
        self._turn_count += 1
        await self._history.update_session_metadata(
            self._session_id,
            {"owner": self._owner, "total_cost_usd": self._total_cost},
        )

    def _check_session_budget(self) -> None:
        """Raise `BudgetExceeded` if the session total exceeds its budget."""
        if self._per_session_budget is not None and self._total_cost > self._per_session_budget:
            raise BudgetExceeded(
                f"chat session total ${self._total_cost:.4f} exceeds per-session "
                f"budget ${self._per_session_budget:.4f}"
            )

    def _enforce_budgets(self, turn_cost: float) -> None:
        """Post-turn budget check: per-turn first, then per-session."""
        if self._per_turn_budget is not None and turn_cost > self._per_turn_budget:
            raise BudgetExceeded(
                f"chat turn cost ${turn_cost:.4f} exceeds per-turn budget "
                f"${self._per_turn_budget:.4f}"
            )
        self._check_session_budget()

    def _strategy_overrides_stream(self) -> bool:
        """True when the agent's strategy defines its own `stream()`.

        Distinguishes "real per-token streaming" from the default
        ABC behaviour (which just wraps `run()` + emits one `done`).
        Real per-token strategies override `stream()` to yield text /
        tool-call events as the LLM emits them. v0.2 falls back to
        buffer-then-stream when the override isn't there so v0.1
        callers get the same wire shape they had before.
        """
        return type(self._agent._strategy).stream is not ReasoningStrategy.stream

    @staticmethod
    def _reportable(exc: BaseException, cancellation: asyncio.Event | None) -> bool:
        """Whether a caught turn failure becomes an ``error`` chunk.

        Budget and guardrail failures always do. A `CancelledError` is
        converted only when the caller's ``cancellation`` event is set,
        meaning it is the cooperative pre-LLM cancellation raised by this
        class. Any other `CancelledError` is real task cancellation and
        must propagate; swallowing it would keep a cancelled task running.
        """
        if isinstance(exc, asyncio.CancelledError):
            return cancellation is not None and cancellation.is_set()
        return True

    @staticmethod
    def _error_chunk(exc: BaseException) -> ChatChunk:
        """Wire representation of a failed turn."""
        return ChatChunk(
            kind="error",
            turn_id=uuid4().hex,
            content={"reason": type(exc).__name__, "message": str(exc)},
        )

    async def _stream_impl(
        self,
        message: str,
        idempotency_key: str | None,
        cancellation: asyncio.Event | None,
    ) -> AsyncIterator[ChatChunk]:
        """Async generator behind `stream()`; holds the session lock throughout."""
        async with self._lock:
            cached = self._check_cache(idempotency_key)
            if cached is not None:
                async for chunk in self._chunks_for(cached):
                    yield chunk
                return
            if self._strategy_overrides_stream():
                try:
                    async for chunk in self._stream_per_token(
                        message,
                        cancellation=cancellation,
                        idempotency_key=idempotency_key,
                    ):
                        yield chunk
                except (BudgetExceeded, GuardrailViolation, asyncio.CancelledError) as exc:
                    if not self._reportable(exc, cancellation):
                        raise
                    yield self._error_chunk(exc)
                return
            try:
                response, _ = await self._run_turn(message, cancellation=cancellation)
            except (BudgetExceeded, GuardrailViolation, asyncio.CancelledError) as exc:
                if not self._reportable(exc, cancellation):
                    raise
                yield self._error_chunk(exc)
                return
            self._stash_cache(idempotency_key, response)
            async for chunk in self._chunks_for(response):
                yield chunk

    async def _stream_per_token(  # noqa: PLR0912
        self,
        message: str,
        *,
        cancellation: asyncio.Event | None,
        idempotency_key: str | None = None,
    ) -> AsyncIterator[ChatChunk]:
        """Drive the agent via `agent.stream(task)` and forward every
        `StreamingEvent` as a `ChatChunk`.

        Persists the user and final assistant turns and updates
        per-session cost the same way `_run_turn` does. When
        ``idempotency_key`` is given, the final reply is also cached, so
        a retry replays it instead of re-running the agent.

        When ``safety_mode == "sentence-window"`` (or its current
        alias ``"stream-then-redact"``), `text` events are buffered
        until a sentence boundary. Each completed sentence runs
        through ``check_output`` before being emitted to the wire.
        Non-text events (``tool_call``, ``step``, ``error``) pass
        through unbuffered.
        """
        self._check_session_budget()
        ctx = self._guard_context()
        validated_msg = await self._agent._guardrails.check_input(message, ctx)
        user_turn = await self._build_user_turn(validated_msg)
        if cancellation is not None and cancellation.is_set():
            raise asyncio.CancelledError("chat turn cancelled before agent.stream")
        task = await self._compose_task(user_turn)
        assistant_turn_id = uuid4().hex
        cumulative = ""
        run_summary: dict[str, Any] | None = None
        start = time.monotonic()
        buffered = self._safety_mode in ("sentence-window", "stream-then-redact")
        # NOTE(review): in "buffer-then-stream" mode raw text events reach the
        # wire before `check_output` runs on the final text below. Confirm this
        # matches the ChatSessionConfig semantics for that mode.
        window = _SentenceWindowBuffer() if buffered else None
        async for event in self._agent.stream(task):
            if event.kind == "done":
                if isinstance(event.content, dict):
                    run_summary = event.content
                # Don't break — let the generator yield its terminal
                # event and complete naturally so `Agent.stream`'s
                # `finally: reset_run(token)` fires deterministically.
                continue
            if event.kind == "text" and isinstance(event.content, str):
                if window is not None:
                    for sentence in window.push(event.content):
                        validated = await self._agent._guardrails.check_output(sentence, ctx)
                        cumulative += validated
                        yield ChatChunk(
                            kind="text",
                            content=validated,
                            cumulative_text=cumulative,
                            turn_id=assistant_turn_id,
                            metadata=dict(event.metadata),
                        )
                    continue
                cumulative += event.content
            yield self._chunk_from_event(event, assistant_turn_id)
        if window is not None:
            residual = window.flush()
            if residual:
                validated = await self._agent._guardrails.check_output(residual, ctx)
                cumulative += validated
                yield ChatChunk(
                    kind="text",
                    content=validated,
                    cumulative_text=cumulative,
                    turn_id=assistant_turn_id,
                    metadata={},
                )
        duration_ms = int((time.monotonic() - start) * 1000)
        if run_summary is None:
            run_summary = {
                "output": cumulative,
                "run_id": uuid4().hex,
                "cost_usd": 0.0,
                "tokens_in": 0,
                "tokens_out": 0,
                "finish_reason": "completed",
                "duration_ms": duration_ms,
            }
        if buffered:
            # Each sentence already passed through `check_output`;
            # `cumulative` is the validated text. Skip the terminal
            # double-validation.
            validated_out = cumulative
        else:
            reported = run_summary.get("output")
            final_text = reported if isinstance(reported, str) else cumulative
            validated_out = await self._agent._guardrails.check_output(final_text, ctx)
        run_id = str(run_summary.get("run_id", ""))
        tokens_in = int(run_summary.get("tokens_in", 0) or 0)
        tokens_out = int(run_summary.get("tokens_out", 0) or 0)
        cost_usd = float(run_summary.get("cost_usd", 0.0) or 0.0)
        assistant_turn = ChatTurn(
            id=assistant_turn_id,
            session_id=self._session_id,
            role="assistant",
            content=validated_out,
            run_id=run_id,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost_usd=cost_usd,
        )
        await self._record_assistant_turn(assistant_turn, cost_usd)
        self._enforce_budgets(cost_usd)
        if idempotency_key is not None:
            # Mirror the buffered path: a retried key replays this reply.
            self._stash_cache(
                idempotency_key,
                ChatResponse(
                    content=validated_out,
                    turn_id=assistant_turn_id,
                    run_id=run_id,
                    tool_calls=(),
                    tokens_in=tokens_in,
                    tokens_out=tokens_out,
                    cost_usd=cost_usd,
                    duration_ms=duration_ms,
                    finish_reason=str(run_summary.get("finish_reason", "completed")),
                ),
            )
        yield ChatChunk(
            kind="done",
            turn_id=assistant_turn_id,
            content={
                "run_id": run_id,
                "cost_usd": cost_usd,
                "tokens_in": tokens_in,
                "tokens_out": tokens_out,
            },
        )

    def _chunk_from_event(self, event: StreamingEvent, turn_id: str) -> ChatChunk:
        """Re-label a strategy `StreamingEvent` as a `ChatChunk` for ``turn_id``."""
        return ChatChunk(
            kind=event.kind,
            content=event.content,
            cumulative_text=event.cumulative_text,
            turn_id=turn_id,
            metadata=dict(event.metadata),
        )

    async def _chunks_for(self, response: ChatResponse) -> AsyncIterator[ChatChunk]:
        """Replay a buffered response as segmented ``text`` chunks + ``done``."""
        cumulative = ""
        for piece in segment_for_stream(response.content):
            cumulative += piece
            yield ChatChunk(
                kind="text",
                content=piece,
                cumulative_text=cumulative,
                turn_id=response.turn_id,
            )
        yield ChatChunk(
            kind="done",
            turn_id=response.turn_id,
            content={
                "run_id": response.run_id,
                "cost_usd": response.cost_usd,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
            },
        )
494
+
495
+
496
# Public surface of this module: only the session wrapper is exported.
__all__: list[str] = ["ChatSession"]