rossum-agent 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. rossum_agent/__init__.py +9 -0
  2. rossum_agent/agent/__init__.py +32 -0
  3. rossum_agent/agent/core.py +932 -0
  4. rossum_agent/agent/memory.py +176 -0
  5. rossum_agent/agent/models.py +160 -0
  6. rossum_agent/agent/request_classifier.py +152 -0
  7. rossum_agent/agent/skills.py +132 -0
  8. rossum_agent/agent/types.py +5 -0
  9. rossum_agent/agent_logging.py +56 -0
  10. rossum_agent/api/__init__.py +1 -0
  11. rossum_agent/api/cli.py +51 -0
  12. rossum_agent/api/dependencies.py +190 -0
  13. rossum_agent/api/main.py +180 -0
  14. rossum_agent/api/models/__init__.py +1 -0
  15. rossum_agent/api/models/schemas.py +301 -0
  16. rossum_agent/api/routes/__init__.py +1 -0
  17. rossum_agent/api/routes/chats.py +95 -0
  18. rossum_agent/api/routes/files.py +113 -0
  19. rossum_agent/api/routes/health.py +44 -0
  20. rossum_agent/api/routes/messages.py +218 -0
  21. rossum_agent/api/services/__init__.py +1 -0
  22. rossum_agent/api/services/agent_service.py +451 -0
  23. rossum_agent/api/services/chat_service.py +197 -0
  24. rossum_agent/api/services/file_service.py +65 -0
  25. rossum_agent/assets/Primary_light_logo.png +0 -0
  26. rossum_agent/bedrock_client.py +64 -0
  27. rossum_agent/prompts/__init__.py +27 -0
  28. rossum_agent/prompts/base_prompt.py +80 -0
  29. rossum_agent/prompts/system_prompt.py +24 -0
  30. rossum_agent/py.typed +0 -0
  31. rossum_agent/redis_storage.py +482 -0
  32. rossum_agent/rossum_mcp_integration.py +123 -0
  33. rossum_agent/skills/hook-debugging.md +31 -0
  34. rossum_agent/skills/organization-setup.md +60 -0
  35. rossum_agent/skills/rossum-deployment.md +102 -0
  36. rossum_agent/skills/schema-patching.md +61 -0
  37. rossum_agent/skills/schema-pruning.md +23 -0
  38. rossum_agent/skills/ui-settings.md +45 -0
  39. rossum_agent/streamlit_app/__init__.py +1 -0
  40. rossum_agent/streamlit_app/app.py +646 -0
  41. rossum_agent/streamlit_app/beep_sound.py +36 -0
  42. rossum_agent/streamlit_app/cli.py +17 -0
  43. rossum_agent/streamlit_app/render_modules.py +123 -0
  44. rossum_agent/streamlit_app/response_formatting.py +305 -0
  45. rossum_agent/tools/__init__.py +214 -0
  46. rossum_agent/tools/core.py +173 -0
  47. rossum_agent/tools/deploy.py +404 -0
  48. rossum_agent/tools/dynamic_tools.py +365 -0
  49. rossum_agent/tools/file_tools.py +62 -0
  50. rossum_agent/tools/formula.py +187 -0
  51. rossum_agent/tools/skills.py +31 -0
  52. rossum_agent/tools/spawn_mcp.py +227 -0
  53. rossum_agent/tools/subagents/__init__.py +31 -0
  54. rossum_agent/tools/subagents/base.py +303 -0
  55. rossum_agent/tools/subagents/hook_debug.py +591 -0
  56. rossum_agent/tools/subagents/knowledge_base.py +305 -0
  57. rossum_agent/tools/subagents/mcp_helpers.py +47 -0
  58. rossum_agent/tools/subagents/schema_patching.py +471 -0
  59. rossum_agent/url_context.py +167 -0
  60. rossum_agent/user_detection.py +100 -0
  61. rossum_agent/utils.py +128 -0
  62. rossum_agent-1.0.0rc0.dist-info/METADATA +311 -0
  63. rossum_agent-1.0.0rc0.dist-info/RECORD +67 -0
  64. rossum_agent-1.0.0rc0.dist-info/WHEEL +5 -0
  65. rossum_agent-1.0.0rc0.dist-info/entry_points.txt +3 -0
  66. rossum_agent-1.0.0rc0.dist-info/licenses/LICENSE +21 -0
  67. rossum_agent-1.0.0rc0.dist-info/top_level.txt +1 -0
rossum_agent/agent/core.py
@@ -0,0 +1,932 @@
+ """Core agent module implementing the RossumAgent class with Anthropic tool use API.
+
+ This module provides the main agent loop for interacting with the Rossum platform
+ using Claude models via AWS Bedrock and MCP tools.
+
+ Streaming Architecture & AgentStep Yield Points
+ ================================================
+
+ The agent streams responses via `_stream_model_response` which yields `AgentStep` objects
+ at multiple points to provide real-time updates to the client. The yield flow is:
+
+     _stream_model_response
+     ├── #5 forwards from _process_stream_events ──┬── #1 Timeout flush (buffer stale after 1.5s)
+     │                                             ├── #2 Stream end flush (final text)
+     │                                             ├── #3 Thinking tokens (chain-of-thought)
+     │                                             └── #4 Text deltas (after initial buffer)
+     ├── #6 Final answer (no tools, response complete)
+     └── #7 forwards from _execute_tools_with_progress
+         ├── Tool starting (which tool is about to run)
+         └── Sub-agent progress (from nested agent tools like debug_hook)
+
+ Key concepts:
+ - Initial text buffering (INITIAL_TEXT_BUFFER_DELAY=1.5s) allows determining step type
+   (INTERMEDIATE vs FINAL_ANSWER) before streaming to client
+ - After initial flush, text tokens stream immediately
+ - Tool execution yields progress updates for UI responsiveness
+ - In a single step, a thinking block is always followed by an intermediate block
+   (tool calls or text response)
+ """
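As a rough editorial sketch of how a client might consume the yield points documented above (an illustration, not code shipped in the wheel; it assumes the create_agent factory and the AgentStep/StepType fields defined later in this module):

# Editorial sketch - a hypothetical consumer of the step stream.
from rossum_agent.agent.core import create_agent
from rossum_agent.agent.models import StepType


async def consume(mcp_connection, system_prompt: str) -> None:
    agent = await create_agent(mcp_connection, system_prompt)
    async for step in agent.run("List my Rossum queues"):
        if step.step_type == StepType.THINKING:
            print("thinking:", step.thinking)        # yield #3
        elif step.is_streaming and step.text_delta:
            print("text:", step.text_delta)          # yields #1, #2, #4
        elif step.is_streaming and step.tool_calls:
            print("tools:", step.tool_progress)      # yield #7 (tool-start and sub-agent updates)
        elif step.is_final:
            print("answer:", step.final_answer)      # yield #6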
+
+ from __future__ import annotations
+
+ import asyncio
+ import dataclasses
+ import json
+ import logging
+ import queue
+ import random
+ import time
+ from contextvars import copy_context
+ from functools import partial
+ from typing import TYPE_CHECKING
+
+ from anthropic import APIError, APITimeoutError, RateLimitError
+ from anthropic._types import Omit
+ from anthropic.types import (
+     ContentBlockStopEvent,
+     InputJSONDelta,
+     Message,
+     MessageParam,
+     MessageStreamEvent,
+     RawContentBlockDeltaEvent,
+     RawContentBlockStartEvent,
+     TextBlockParam,
+     TextDelta,
+     ThinkingBlock,
+     ThinkingConfigEnabledParam,
+     ThinkingDelta,
+     ToolParam,
+     ToolUseBlock,
+ )
+ from pydantic import BaseModel
+
+ from rossum_agent.agent.memory import AgentMemory, MemoryStep
+ from rossum_agent.agent.models import (
+     AgentConfig,
+     AgentStep,
+     StepType,
+     StreamDelta,
+     ThinkingBlockData,
+     ToolCall,
+     ToolResult,
+     truncate_content,
+ )
+ from rossum_agent.agent.request_classifier import RequestScope, classify_request, generate_rejection_response
+ from rossum_agent.api.models.schemas import TokenUsageBreakdown
+ from rossum_agent.bedrock_client import create_bedrock_client, get_model_id
+ from rossum_agent.rossum_mcp_integration import MCPConnection, mcp_tools_to_anthropic_format
+ from rossum_agent.tools import (
+     DEPLOY_TOOLS,
+     DISCOVERY_TOOL_NAME,
+     SubAgentProgress,
+     SubAgentTokenUsage,
+     execute_internal_tool,
+     execute_tool,
+     get_deploy_tool_names,
+     get_deploy_tools,
+     get_dynamic_tools,
+     get_internal_tool_names,
+     get_internal_tools,
+     preload_categories_for_request,
+     reset_dynamic_tools,
+     set_mcp_connection,
+     set_progress_callback,
+     set_token_callback,
+ )
+
+ if TYPE_CHECKING:
+     from collections.abc import AsyncIterator, Iterator
+     from typing import Literal
+
+     from anthropic import AnthropicBedrock
+
+     from rossum_agent.agent.types import UserContent
+
+ logger = logging.getLogger(__name__)
+
+
+ RATE_LIMIT_MAX_RETRIES = 5
+ RATE_LIMIT_BASE_DELAY = 2.0
+ RATE_LIMIT_MAX_DELAY = 60.0
+
+ # Buffer text tokens for this duration before first flush to allow time to determine
+ # whether this is an intermediate step (with tool calls) or final answer text.
+ # This delay helps correctly classify the step type before streaming to the client.
+ INITIAL_TEXT_BUFFER_DELAY = 1.5
+
+
+ def _parse_json_encoded_strings(arguments: dict) -> dict:
+     """Recursively parse JSON-encoded strings in tool arguments.
+
+     LLMs sometimes generate JSON-encoded strings for list/dict arguments instead of
+     actual lists/dicts. This function detects and parses such strings.
+
+     For example, converts:
+         {"fields_to_keep": "[\"a\", \"b\"]"}
+     To:
+         {"fields_to_keep": ["a", "b"]}
+     """
+     # Parameters that should remain as JSON strings (not parsed to lists/dicts)
+     keep_as_string = {"changes"}
+
+     result = {}
+     for key, value in arguments.items():
+         if key in keep_as_string:
+             result[key] = value
+         elif isinstance(value, str) and value.startswith(("[", "{")):
+             try:
+                 parsed = json.loads(value)
+                 if isinstance(parsed, (list, dict)):
+                     result[key] = parsed
+                 else:
+                     result[key] = value
+             except json.JSONDecodeError:
+                 result[key] = value
+         elif isinstance(value, dict):
+             result[key] = _parse_json_encoded_strings(value)
+         else:
+             result[key] = value
+     return result
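A quick editorial check of the conversion described in the docstring (illustration only, not part of the module):

# Hypothetical tool arguments as an LLM might emit them:
args = {
    "fields_to_keep": '["a", "b"]',          # JSON-encoded list -> parsed to a real list
    "changes": '[{"op": "remove"}]',         # listed in keep_as_string -> left untouched
    "options": {"ids": "[1, 2]"},            # nested dicts are handled recursively
    "note": "{not valid json",               # unparseable -> kept verbatim
}
print(_parse_json_encoded_strings(args))
# {'fields_to_keep': ['a', 'b'], 'changes': '[{"op": "remove"}]',
#  'options': {'ids': [1, 2]}, 'note': '{not valid json'}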
+
+
+ @dataclasses.dataclass
+ class _StreamState:
+     """Mutable state for streaming model response.
+
+     Attributes:
+         first_text_token_time: Timestamp of when the first text token was received.
+             Used to implement initial buffering delay (see INITIAL_TEXT_BUFFER_DELAY).
+         initial_buffer_flushed: Whether the initial buffer has been flushed after
+             the delay period. Once True, text tokens are streamed immediately.
+     """
+
+     thinking_text: str = ""
+     response_text: str = ""
+     final_message: Message | None = None
+     text_buffer: list[str] = dataclasses.field(default_factory=list)
+     tool_calls: list[ToolCall] = dataclasses.field(default_factory=list)
+     pending_tools: dict[int, dict[str, str]] = dataclasses.field(default_factory=dict)
+     first_text_token_time: float | None = None
+     initial_buffer_flushed: bool = False
+
+     def _should_flush_initial_buffer(self) -> bool:
+         """Check if the initial buffer delay has elapsed and buffer should be flushed."""
+         if self.initial_buffer_flushed:
+             return True
+         if self.first_text_token_time is None:
+             return False
+         return (time.monotonic() - self.first_text_token_time) >= INITIAL_TEXT_BUFFER_DELAY
+
+     def get_step_type(self) -> StepType:
+         """Get the step type based on whether tool calls are pending."""
+         return StepType.INTERMEDIATE if self.pending_tools or self.tool_calls else StepType.FINAL_ANSWER
+
+     def flush_buffer(self, step_num: int, step_type: StepType) -> AgentStep | None:
+         """Flush text buffer and return AgentStep if buffer had content."""
+         if not self.text_buffer:
+             return None
+         buffered_text = "".join(self.text_buffer)
+         self.text_buffer.clear()
+         self.response_text += buffered_text
+         return AgentStep(
+             step_number=step_num,
+             thinking=self.thinking_text or None,
+             is_streaming=True,
+             text_delta=buffered_text,
+             accumulated_text=self.response_text,
+             step_type=step_type,
+         )
+
+     @property
+     def contains_thinking(self) -> bool:
+         return bool(self.thinking_text)
+
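To make the buffering contract concrete, a short editorial walkthrough of the state machine above (poking the private dataclass directly; illustration only):

# Editorial walkthrough of the initial-buffer decision.
state = _StreamState()

state.first_text_token_time = time.monotonic()
state.text_buffer.append("Hello")
assert state._should_flush_initial_buffer() is False  # < 1.5s elapsed: keep buffering

state.first_text_token_time = time.monotonic() - 2.0  # pretend 2 seconds have passed
assert state._should_flush_initial_buffer() is True   # delay elapsed: safe to flush

state.pending_tools[0] = {"name": "list_queues", "id": "t1", "json": ""}
assert state.get_step_type() is StepType.INTERMEDIATE  # a tool_use block has started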
+
+ class RossumAgent:
+     """Claude-powered agent for Rossum document processing.
+
+     This agent uses Anthropic's tool use API to interact with the Rossum platform
+     via MCP tools. It maintains conversation state across multiple turns and
+     supports streaming responses.
+
+     Memory is stored as structured MemoryStep objects and rebuilt into messages
+     each call.
+     """
+
+     def __init__(
+         self,
+         client: AnthropicBedrock,
+         mcp_connection: MCPConnection,
+         system_prompt: str,
+         config: AgentConfig | None = None,
+         additional_tools: list[ToolParam] | None = None,
+     ) -> None:
+         self.client = client
+         self.mcp_connection = mcp_connection
+         self.system_prompt = system_prompt
+         self.config = config or AgentConfig()
+         self.additional_tools = additional_tools or []
+
+         self.memory = AgentMemory()
+         self._tools_cache: list[ToolParam] | None = None
+         self._total_input_tokens: int = 0
+         self._total_output_tokens: int = 0
+         # Token breakdown tracking
+         self._main_agent_input_tokens: int = 0
+         self._main_agent_output_tokens: int = 0
+         self._sub_agent_input_tokens: int = 0
+         self._sub_agent_output_tokens: int = 0
+         self._sub_agent_usage: dict[str, tuple[int, int]] = {}  # tool_name -> (input, output)
+
+     @property
+     def messages(self) -> list[MessageParam]:
+         """Get the current conversation messages (rebuilt from memory)."""
+         return self.memory.write_to_messages()
+
+     def reset(self) -> None:
+         """Reset the agent's conversation state."""
+         self.memory.reset()
+         self._total_input_tokens = 0
+         self._total_output_tokens = 0
+         self._main_agent_input_tokens = 0
+         self._main_agent_output_tokens = 0
+         self._sub_agent_input_tokens = 0
+         self._sub_agent_output_tokens = 0
+         self._sub_agent_usage = {}
+         reset_dynamic_tools()
+
+     def _accumulate_sub_agent_tokens(self, usage: SubAgentTokenUsage) -> None:
+         """Accumulate token usage from a sub-agent call."""
+         self._total_input_tokens += usage.input_tokens
+         self._total_output_tokens += usage.output_tokens
+         self._sub_agent_input_tokens += usage.input_tokens
+         self._sub_agent_output_tokens += usage.output_tokens
+         # Track per sub-agent
+         prev_in, prev_out = self._sub_agent_usage.get(usage.tool_name, (0, 0))
+         self._sub_agent_usage[usage.tool_name] = (prev_in + usage.input_tokens, prev_out + usage.output_tokens)
+         logger.info(
+             f"Sub-agent '{usage.tool_name}' token usage (iter {usage.iteration}): "
+             f"in={usage.input_tokens}, out={usage.output_tokens}, "
+             f"cumulative total: in={self._total_input_tokens}, out={self._total_output_tokens}"
+         )
+
+     def get_token_usage_breakdown(self) -> TokenUsageBreakdown:
+         """Get token usage breakdown by agent vs sub-agents."""
+         return TokenUsageBreakdown.from_raw_counts(
+             total_input=self._total_input_tokens,
+             total_output=self._total_output_tokens,
+             main_input=self._main_agent_input_tokens,
+             main_output=self._main_agent_output_tokens,
+             sub_input=self._sub_agent_input_tokens,
+             sub_output=self._sub_agent_output_tokens,
+             sub_by_tool=self._sub_agent_usage,
+         )
+
+     def log_token_usage_summary(self) -> None:
+         """Log a human-readable token usage summary."""
+         breakdown = self.get_token_usage_breakdown()
+         logger.info("\n".join(breakdown.format_summary_lines()))
+
+     def add_user_message(self, content: UserContent) -> None:
+         """Add a user message to the conversation history."""
+         self.memory.add_task(content)
+
+     def add_assistant_message(self, content: str) -> None:
+         """Add an assistant message to the conversation history.
+
+         This creates a MemoryStep with text set, which ensures the
+         message is properly serialized when rebuilding conversation history.
+         For proper conversation flow with tool use, use the run() method instead.
+         """
+         step = MemoryStep(step_number=0, text=content)
+         self.memory.add_step(step)
+
+     async def _get_tools(self) -> list[ToolParam]:
+         """Get all available tools in Anthropic format.
+
+         Initially loads only the discovery tool from MCP (list_tool_categories)
+         plus internal tools and deploy tools. Additional MCP tools are loaded dynamically
+         via load_tool_category and added to the tool list through get_dynamic_tools().
+         """
+         if self._tools_cache is None:
+             # Load only discovery tool from MCP initially to reduce context usage
+             mcp_tools = await self.mcp_connection.get_tools()
+             discovery_tools = [t for t in mcp_tools if t.name == DISCOVERY_TOOL_NAME]
+             self._tools_cache = (
+                 mcp_tools_to_anthropic_format(discovery_tools)
+                 + get_internal_tools()
+                 + get_deploy_tools()
+                 + self.additional_tools
+             )
+         # Include dynamically loaded tools
+         return self._tools_cache + get_dynamic_tools()
+
+     def _serialize_tool_result(self, result: object) -> str:
+         """Serialize a tool result to a string for storage in context.
+
+         Handles pydantic models, dataclasses, dicts, lists, and other objects properly.
+         """
+         if result is None:
+             return "Tool executed successfully (no output)"
+
+         # Handle dataclasses (check before pydantic since pydantic models aren't dataclasses)
+         if dataclasses.is_dataclass(result) and not isinstance(result, type):
+             return json.dumps(dataclasses.asdict(result), indent=2, default=str)
+
+         # Handle lists of dataclasses
+         if isinstance(result, list) and result and dataclasses.is_dataclass(result[0]):
+             return json.dumps(
+                 [
+                     dataclasses.asdict(item)
+                     for item in result
+                     if dataclasses.is_dataclass(item) and not isinstance(item, type)
+                 ],
+                 indent=2,
+                 default=str,
+             )
+
+         # Handle pydantic models
+         # Use mode='json' to ensure nested models are properly serialized to JSON-compatible dicts
+         if isinstance(result, BaseModel):
+             return json.dumps(result.model_dump(mode="json"), indent=2, default=str)
+
+         # Handle lists of pydantic models
+         if isinstance(result, list) and result and isinstance(result[0], BaseModel):
+             return json.dumps(
+                 [item.model_dump(mode="json") for item in result if isinstance(item, BaseModel)],
+                 indent=2,
+                 default=str,
+             )
+
+         # Handle dicts and regular lists
+         if isinstance(result, dict | list):
+             return json.dumps(result, indent=2, default=str)
+
+         # Fallback to string representation
+         return str(result)
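For intuition, an editorial example of how each result shape is serialized (illustration only; Queue and Workspace are throwaway stand-ins, and agent is any RossumAgent instance):

import dataclasses
from pydantic import BaseModel

@dataclasses.dataclass
class Queue:
    id: int
    name: str

class Workspace(BaseModel):
    id: int
    url: str

agent._serialize_tool_result(None)                  # "Tool executed successfully (no output)"
agent._serialize_tool_result(Queue(1, "Invoices"))  # '{"id": 1, "name": "Invoices"}' (indented)
agent._serialize_tool_result([Workspace(id=2, url="https://example.invalid/w/2")])
                                                    # JSON array of model_dump(mode="json") dicts
agent._serialize_tool_result({"ok": True})          # dicts/lists go straight through json.dumps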
+
+     def _sync_stream_events(
+         self, model_id: str, messages: list[MessageParam], tools: list[ToolParam]
+     ) -> Iterator[tuple[MessageStreamEvent | None, Message | None]]:
+         """Synchronous generator that yields stream events and final message.
+
+         This runs in a thread pool to avoid blocking the event loop.
+
+         Yields:
+             Tuples of (event, None) for each stream event, then (None, final_message) at the end.
+         """
+         thinking_config: ThinkingConfigEnabledParam = {
+             "type": "enabled",
+             "budget_tokens": self.config.thinking_budget_tokens,
+         }
+         with self.client.messages.stream(
+             model=model_id,
+             max_tokens=self.config.max_output_tokens,
+             system=self.system_prompt,
+             messages=messages,
+             tools=tools if tools else Omit(),
+             thinking=thinking_config,
+             temperature=self.config.temperature,
+         ) as stream:
+             for event in stream:
+                 yield (event, None)
+             yield (None, stream.get_final_message())
+
+     def _process_stream_event(
+         self,
+         event: MessageStreamEvent,
+         pending_tools: dict[int, dict[str, str]],
+         tool_calls: list[ToolCall],
+     ) -> StreamDelta | None:
+         """Process a single stream event.
+
+         Returns:
+             StreamDelta with kind="thinking" or "text", or None if no delta.
+         """
+         if isinstance(event, RawContentBlockStartEvent):
+             if isinstance(event.content_block, ToolUseBlock):
+                 pending_tools[event.index] = {
+                     "name": event.content_block.name,
+                     "id": event.content_block.id,
+                     "json": "",
+                 }
+
+         elif isinstance(event, RawContentBlockDeltaEvent):
+             if isinstance(event.delta, ThinkingDelta):
+                 return StreamDelta(kind="thinking", content=event.delta.thinking)
+             if isinstance(event.delta, TextDelta):
+                 return StreamDelta(kind="text", content=event.delta.text)
+             if isinstance(event.delta, InputJSONDelta) and event.index in pending_tools:
+                 pending_tools[event.index]["json"] += event.delta.partial_json
+
+         elif isinstance(event, ContentBlockStopEvent) and event.index in pending_tools:
+             tool_info = pending_tools.pop(event.index)
+             try:
+                 arguments = json.loads(tool_info["json"]) if tool_info["json"] else {}
+                 arguments = _parse_json_encoded_strings(arguments)
+             except json.JSONDecodeError as e:
+                 logger.warning("Failed to decode tool arguments for %s: %s", tool_info["name"], e)
+                 arguments = {}
+             tool_calls.append(ToolCall(id=tool_info["id"], name=tool_info["name"], arguments=arguments))
+
+         return None
+
+     def _extract_thinking_blocks(self, message: Message) -> list[ThinkingBlockData]:
+         """Extract thinking blocks from a message for preserving in conversation history."""
+         return [
+             ThinkingBlockData(thinking=block.thinking, signature=block.signature)
+             for block in message.content
+             if isinstance(block, ThinkingBlock)
+         ]
+
+     def _handle_text_delta(
+         self, step_num: int, content: str, delta_kind: Literal["thinking", "text"], state: _StreamState
+     ) -> AgentStep | None:
+         """Handle a text delta, buffering or flushing as appropriate."""
+         if state.first_text_token_time is None:
+             state.first_text_token_time = time.monotonic()
+         else:
+             if time.monotonic() - state.first_text_token_time > INITIAL_TEXT_BUFFER_DELAY:
+                 state.initial_buffer_flushed = True
+
+         state.text_buffer.append(content)
+
+         if state.initial_buffer_flushed:
+             step_type = (
+                 StepType.INTERMEDIATE if state.contains_thinking and delta_kind == "text" else state.get_step_type()
+             )
+             return state.flush_buffer(step_num, step_type)
+         if state.pending_tools or state.tool_calls:
+             state.initial_buffer_flushed = True
+             return state.flush_buffer(step_num, StepType.INTERMEDIATE)
+         return None
+
+     async def _process_stream_events(
+         self,
+         step_num: int,
+         event_queue: queue.Queue[tuple[MessageStreamEvent | None, Message | None] | None],
+         state: _StreamState,
+     ) -> AsyncIterator[AgentStep]:
+         """Process stream events and yield AgentSteps.
+
+         Text tokens are buffered for INITIAL_TEXT_BUFFER_DELAY seconds after the first
+         text token is received. This allows time to determine whether the response will
+         include tool calls (intermediate step) or is a final answer, enabling correct
+         step type classification before streaming to the client.
+
+         After the initial buffer is flushed, subsequent text tokens are streamed immediately.
+         """
+         while True:
+             try:
+                 item = await asyncio.to_thread(event_queue.get, timeout=INITIAL_TEXT_BUFFER_DELAY)
+             except queue.Empty:
+                 # Yield #1: Timeout-based flush of initial text buffer (ensures responsiveness during model pauses)
+                 if (
+                     state.text_buffer
+                     and state._should_flush_initial_buffer()
+                     and (step := state.flush_buffer(step_num, state.get_step_type()))
+                 ):
+                     state.initial_buffer_flushed = True
+                     yield step
+                 continue
+
+             if item is None:
+                 # Yield #2: Stream ended - flush any remaining buffered text
+                 if step := state.flush_buffer(step_num, state.get_step_type()):
+                     yield step
+                 break
+
+             event, final_msg = item
+             if final_msg is not None:
+                 state.final_message = final_msg
+                 continue
+
+             if event is None:
+                 continue
+
+             delta = self._process_stream_event(event, state.pending_tools, state.tool_calls)
+             if not delta:
+                 continue
+
+             if delta.kind == "thinking":
+                 state.thinking_text += delta.content
+                 # Yield #3: Streaming thinking tokens (extended thinking / chain-of-thought)
+                 yield AgentStep(
+                     step_number=step_num,
+                     thinking=state.thinking_text,
+                     is_streaming=True,
+                     step_type=StepType.THINKING,
+                 )
+                 if state.first_text_token_time is None:
+                     state.first_text_token_time = time.monotonic()
+                 continue
+
+             # Yield #4: Text delta - immediate flush after initial buffer period or when tool calls detected
+             if step := self._handle_text_delta(step_num, delta.content, delta.kind, state):
+                 yield step
+
+     async def _stream_model_response(self, step_num: int) -> AsyncIterator[AgentStep]:
+         """Stream model response, yielding partial steps as thinking streams in.
+
+         Extended thinking separates the model's internal reasoning (thinking blocks)
+         from its final response (text blocks). This allows distinguishing between
+         the chain-of-thought process and the actual answer.
+
+         Yields:
+             AgentStep objects - partial steps while streaming, then final step with tool results.
+         """
+         messages = self.memory.write_to_messages()
+         tools = await self._get_tools()
+         model_id = get_model_id()
+         state = _StreamState()
+
+         event_queue: queue.Queue[tuple[MessageStreamEvent | None, Message | None] | None] = queue.Queue()
+
+         def producer() -> None:
+             for item in self._sync_stream_events(model_id, messages, tools):
+                 event_queue.put(item)
+             event_queue.put(None)
+
+         ctx = copy_context()
+         producer_task = asyncio.get_event_loop().run_in_executor(None, partial(ctx.run, producer))
+
+         # Yield #5: Forward all streaming steps from _process_stream_events (yields #1-4)
+         async for step in self._process_stream_events(step_num, event_queue, state):
+             yield step
+
+         await producer_task
+
+         if state.final_message is None:
+             raise RuntimeError("Stream ended without final message")
+
+         thinking_blocks = self._extract_thinking_blocks(state.final_message)
+         input_tokens = state.final_message.usage.input_tokens
+         output_tokens = state.final_message.usage.output_tokens
+         self._total_input_tokens += input_tokens
+         self._total_output_tokens += output_tokens
+         self._main_agent_input_tokens += input_tokens
+         self._main_agent_output_tokens += output_tokens
+         logger.info(
+             f"Step {step_num}: input_tokens={input_tokens}, output_tokens={output_tokens}, "
+             f"total_input={self._total_input_tokens}, total_output={self._total_output_tokens}"
+         )
+
+         step = AgentStep(
+             step_number=step_num,
+             thinking=state.thinking_text if state.thinking_text else None,
+             tool_calls=state.tool_calls,
+             is_streaming=False,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             step_type=StepType.FINAL_ANSWER if not state.tool_calls else StepType.INTERMEDIATE,
+         )
+
+         if not state.tool_calls:
+             step.final_answer = state.response_text or None
+             step.is_final = True
+             memory_step = MemoryStep(
+                 step_number=step_num,
+                 text=state.response_text if state.response_text else None,
+                 input_tokens=input_tokens,
+                 output_tokens=output_tokens,
+             )
+             self.memory.add_step(memory_step)
+             # Yield #6: Final answer step (no tool calls, response complete)
+             yield step
+             return
+
+         # Yield #7: Forward tool execution progress steps from _execute_tools_with_progress
+         async for step_or_result in self._execute_tools_with_progress(
+             step_num, state.response_text, state.tool_calls, step, input_tokens, output_tokens, thinking_blocks
+         ):
+             yield step_or_result
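The producer thread plus queue in _stream_model_response is a general pattern for bridging a blocking SDK iterator onto the event loop. Stripped of agent specifics, an editorial sketch of the same idea (not part of the package):

import asyncio
import queue
from contextvars import copy_context
from functools import partial


async def bridge(sync_iter_factory):
    """Yield items from a blocking iterator without stalling the event loop."""
    q: queue.Queue = queue.Queue()

    def producer() -> None:
        for item in sync_iter_factory():  # blocking iteration runs in a worker thread
            q.put(item)
        q.put(None)                       # sentinel: the stream is finished

    ctx = copy_context()                  # carry contextvars into the worker thread
    task = asyncio.get_event_loop().run_in_executor(None, partial(ctx.run, producer))
    while (item := await asyncio.to_thread(q.get)) is not None:
        yield item                        # consumers stay fully async
    await task                            # re-raise any producer exception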
+
+     async def _execute_tools_with_progress(
+         self,
+         step_num: int,
+         thinking_text: str,
+         tool_calls: list[ToolCall],
+         step: AgentStep,
+         input_tokens: int,
+         output_tokens: int,
+         thinking_blocks: list[ThinkingBlockData] | None = None,
+     ) -> AsyncIterator[AgentStep]:
+         """Execute tools in parallel and yield progress updates."""
+         memory_step = MemoryStep(
+             step_number=step_num,
+             text=thinking_text or None,
+             tool_calls=tool_calls,
+             thinking_blocks=thinking_blocks or [],
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+         )
+
+         total_tools = len(tool_calls)
+
+         yield AgentStep(
+             step_number=step_num,
+             thinking=thinking_text or None,
+             tool_calls=tool_calls,
+             is_streaming=True,
+             current_tool=None,
+             tool_progress=(0, total_tools),
+             step_type=StepType.INTERMEDIATE,
+         )
+
+         progress_queue: asyncio.Queue[AgentStep] = asyncio.Queue()
+         results_by_id: dict[str, ToolResult] = {}
+
+         async def execute_single_tool(tool_call: ToolCall, idx: int) -> None:
+             tool_progress = (idx, total_tools)
+             async for progress_or_result in self._execute_tool_with_progress(
+                 tool_call, step_num, tool_calls, tool_progress
+             ):
+                 if isinstance(progress_or_result, AgentStep):
+                     await progress_queue.put(progress_or_result)
+                 elif isinstance(progress_or_result, ToolResult):
+                     results_by_id[tool_call.id] = progress_or_result
+
+         tasks = [
+             asyncio.create_task(execute_single_tool(tool_call, idx)) for idx, tool_call in enumerate(tool_calls, 1)
+         ]
+
+         pending = set(tasks)
+         while pending:
+             _done, pending = await asyncio.wait(pending, timeout=0.05, return_when=asyncio.FIRST_COMPLETED)
+
+             while not progress_queue.empty():
+                 yield progress_queue.get_nowait()
+
+         while not progress_queue.empty():
+             yield progress_queue.get_nowait()
+
+         tool_results = [results_by_id[tc.id] for tc in tool_calls if tc.id in results_by_id]
+
+         step.tool_results = tool_results
+         memory_step.tool_results = tool_results
+
+         self.memory.add_step(memory_step)
+
+         yield step
+
+     def _drain_token_queue(self, token_queue: queue.Queue[SubAgentTokenUsage]) -> None:
+         """Drain all pending token usage from the queue."""
+         while True:
+             try:
+                 usage = token_queue.get_nowait()
+                 self._accumulate_sub_agent_tokens(usage)
+             except queue.Empty:
+                 break
+
+     async def _execute_tool_with_progress(
+         self, tool_call: ToolCall, step_num: int, tool_calls: list[ToolCall], tool_progress: tuple[int, int]
+     ) -> AsyncIterator[AgentStep | ToolResult]:
+         """Execute a tool and yield progress updates for sub-agents.
+
+         For tools with sub-agents (like debug_hook), this yields AgentStep updates
+         with sub_agent_progress. Always yields the final ToolResult.
+         """
+         progress_queue: queue.Queue[SubAgentProgress] = queue.Queue()
+         token_queue: queue.Queue[SubAgentTokenUsage] = queue.Queue()
+
+         def progress_callback(progress: SubAgentProgress) -> None:
+             progress_queue.put(progress)
+
+         def token_callback(usage: SubAgentTokenUsage) -> None:
+             token_queue.put(usage)
+
+         try:
+             if tool_call.name in get_internal_tool_names():
+                 set_progress_callback(progress_callback)
+                 set_token_callback(token_callback)
+
+                 loop = asyncio.get_event_loop()
+                 ctx = copy_context()
+                 future = loop.run_in_executor(
+                     None, partial(ctx.run, execute_internal_tool, tool_call.name, tool_call.arguments)
+                 )
+
+                 while not future.done():
+                     try:
+                         progress = progress_queue.get_nowait()
+                         yield AgentStep(
+                             step_number=step_num,
+                             tool_calls=tool_calls,
+                             is_streaming=True,
+                             current_tool=tool_call.name,
+                             tool_progress=tool_progress,
+                             sub_agent_progress=progress,
+                             step_type=StepType.INTERMEDIATE,
+                         )
+                     except queue.Empty:
+                         pass
+
+                     self._drain_token_queue(token_queue)
+                     await asyncio.sleep(0.1)
+
+                 self._drain_token_queue(token_queue)
+
+                 result = future.result()
+                 content = str(result)
+                 set_progress_callback(None)
+                 set_token_callback(None)
+             elif tool_call.name in get_deploy_tool_names():
+                 loop = asyncio.get_event_loop()
+                 ctx = copy_context()
+                 future = loop.run_in_executor(
+                     None, partial(ctx.run, execute_tool, tool_call.name, tool_call.arguments, DEPLOY_TOOLS)
+                 )
+                 result = await future
+                 content = str(result)
+             else:
+                 result = await self.mcp_connection.call_tool(tool_call.name, tool_call.arguments)
+                 content = self._serialize_tool_result(result)
+
+             content = truncate_content(content)
+             yield ToolResult(tool_call_id=tool_call.id, name=tool_call.name, content=content)
+
+         except Exception as e:
+             set_progress_callback(None)
+             set_token_callback(None)
+             error_msg = f"Tool {tool_call.name} failed: {e}"
+             logger.warning(f"Tool {tool_call.name} failed: {e}", exc_info=True)
+             yield ToolResult(tool_call_id=tool_call.id, name=tool_call.name, content=error_msg, is_error=True)
+
+     def _extract_text_from_prompt(self, prompt: UserContent) -> str:
+         """Extract text content from a user prompt for classification."""
+         if isinstance(prompt, str):
+             return prompt
+         text_parts: list[str] = []
+         for block in prompt:
+             if block.get("type") == "text":
+                 text = block.get("text")
+                 if isinstance(text, str):
+                     text_parts.append(text)
+         return " ".join(text_parts)
+
+     def _check_request_scope(self, prompt: UserContent) -> AgentStep | None:
+         """Check if request is in scope, return rejection step if out of scope."""
+         text = self._extract_text_from_prompt(prompt)
+         result = classify_request(self.client, text)
+         self._total_input_tokens += result.input_tokens
+         self._total_output_tokens += result.output_tokens
+         self._main_agent_input_tokens += result.input_tokens
+         self._main_agent_output_tokens += result.output_tokens
+         if result.scope == RequestScope.OUT_OF_SCOPE:
+             rejection = generate_rejection_response(self.client, text)
+             total_input = result.input_tokens + rejection.input_tokens
+             total_output = result.output_tokens + rejection.output_tokens
+             self._total_input_tokens += rejection.input_tokens
+             self._total_output_tokens += rejection.output_tokens
+             self._main_agent_input_tokens += rejection.input_tokens
+             self._main_agent_output_tokens += rejection.output_tokens
+             return AgentStep(
+                 step_number=1,
+                 final_answer=rejection.response,
+                 is_final=True,
+                 input_tokens=total_input,
+                 output_tokens=total_output,
+                 step_type=StepType.FINAL_ANSWER,
+             )
+         return None
+
+     def _inject_preload_info(self, prompt: UserContent, preload_result: str) -> UserContent:
+         """Inject preload result info into the user prompt."""
+         suffix = (
+             f"\n\n[System: {preload_result}. Use these tools directly without calling list_tool_categories first.]"
+         )
+         if isinstance(prompt, str):
+             return prompt + suffix
+         system_block: TextBlockParam = {"type": "text", "text": suffix}
+         return [*prompt, system_block]
+
+     def _calculate_rate_limit_delay(self, retries: int) -> float:
+         """Calculate exponential backoff delay with jitter for rate limiting."""
+         delay = min(RATE_LIMIT_BASE_DELAY * (2 ** (retries - 1)), RATE_LIMIT_MAX_DELAY)
+         jitter = random.uniform(0, delay * 0.1)
+         return delay + jitter
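Evaluated against the constants defined above (RATE_LIMIT_BASE_DELAY=2.0, RATE_LIMIT_MAX_DELAY=60.0), the base delays before jitter work out as follows (editorial illustration):

# retries=1 -> 2.0s, retries=2 -> 4.0s, retries=3 -> 8.0s,
# retries=4 -> 16.0s, retries=5 -> 32.0s; a 6th attempt would cap at 60.0s,
# but RATE_LIMIT_MAX_RETRIES stops retrying after 5. Jitter adds 0-10% on top.
for r in range(1, 6):
    print(r, min(RATE_LIMIT_BASE_DELAY * (2 ** (r - 1)), RATE_LIMIT_MAX_DELAY))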
+
+     async def run(self, prompt: UserContent) -> AsyncIterator[AgentStep]:
+         """Run the agent with the given prompt, yielding steps.
+
+         This method implements the main agent loop, calling the model,
+         executing tools, and continuing until the model produces a final
+         answer or the maximum number of steps is reached.
+
+         Rate limiting is handled with exponential backoff and jitter.
+         """
+         if rejection := self._check_request_scope(prompt):
+             yield rejection
+             return
+
+         loop = asyncio.get_event_loop()
+         set_mcp_connection(self.mcp_connection, loop)
+
+         # Pre-load tool categories based on keywords in the user's request
+         # Run in thread pool to avoid blocking the event loop (preload uses sync MCP calls)
+         request_text = self._extract_text_from_prompt(prompt)
+         ctx = copy_context()
+         preload_result = await loop.run_in_executor(
+             None, partial(ctx.run, preload_categories_for_request, request_text)
+         )
+
+         # Inject pre-load info into the task so agent knows what tools are available
+         if preload_result:
+             prompt = self._inject_preload_info(prompt, preload_result)
+
+         self.memory.add_task(prompt)
+
+         for step_num in range(1, self.config.max_steps + 1):
+             rate_limit_retries = 0
+
+             # Throttle requests to avoid rate limiting (skip delay on first step)
+             if step_num > 1:
+                 await asyncio.sleep(self.config.request_delay)
+
+             while True:
+                 try:
+                     final_step: AgentStep | None = None
+                     async for step in self._stream_model_response(step_num):
+                         yield step
+                         if not step.is_streaming:
+                             final_step = step
+
+                     if final_step and final_step.is_final:
+                         return
+
+                     break
+
+                 except RateLimitError as e:
+                     rate_limit_retries += 1
+                     if rate_limit_retries > RATE_LIMIT_MAX_RETRIES:
+                         logger.error(f"Rate limit retries exhausted at step {step_num}: {e}")
+                         yield AgentStep(
+                             step_number=step_num,
+                             error=f"Rate limit exceeded after {RATE_LIMIT_MAX_RETRIES} retries. Please try again later.",
+                             is_final=True,
+                             step_type=StepType.FINAL_ANSWER,
+                         )
+                         return
+
+                     wait_time = self._calculate_rate_limit_delay(rate_limit_retries)
+                     logger.warning(
+                         f"Rate limit hit at step {step_num} (attempt {rate_limit_retries}/{RATE_LIMIT_MAX_RETRIES}), "
+                         f"retrying in {wait_time:.1f}s: {e}"
+                     )
+                     yield AgentStep(
+                         step_number=step_num,
+                         thinking=f"⏳ Rate limited, waiting {wait_time:.1f}s before retry ({rate_limit_retries}/{RATE_LIMIT_MAX_RETRIES})...",
+                         is_streaming=True,
+                         step_type=StepType.INTERMEDIATE,
+                     )
+                     await asyncio.sleep(wait_time)
+
+                 except APIError as e:
+                     is_timeout = isinstance(e, APITimeoutError)
+                     log_fn = logger.warning if is_timeout else logger.error
+                     log_fn(f"API {'timeout' if is_timeout else 'error'} at step {step_num}: {e}")
+                     error_msg = (
+                         f"Request timed out. Please try again. Details: {e}"
+                         if is_timeout
+                         else f"API error occurred: {e}"
+                     )
+                     yield AgentStep(
+                         step_number=step_num,
+                         error=error_msg,
+                         is_final=True,
+                         step_type=StepType.FINAL_ANSWER,
+                     )
+                     return
+
+         else:
+             yield AgentStep(
+                 step_number=self.config.max_steps,
+                 error=f"Maximum steps ({self.config.max_steps}) reached without final answer.",
+                 is_final=True,
+                 step_type=StepType.FINAL_ANSWER,
+             )
+
+
+ async def create_agent(
+     mcp_connection: MCPConnection,
+     system_prompt: str,
+     config: AgentConfig | None = None,
+     additional_tools: list[ToolParam] | None = None,
+ ) -> RossumAgent:
+     """Create and configure a RossumAgent instance.
+
+     This is a convenience factory function that creates the Bedrock client
+     and initializes the agent with the provided configuration.
+     """
+     client = create_bedrock_client()
+     return RossumAgent(
+         client=client,
+         mcp_connection=mcp_connection,
+         system_prompt=system_prompt,
+         config=config,
+         additional_tools=additional_tools,
+     )
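Taken together, a minimal multi-turn driver for this module might look like the editorial sketch below (not shipped in the wheel; agent is the result of create_agent above):

# Editorial sketch: multi-turn conversation plus token accounting.
async def converse(agent) -> None:
    async for step in agent.run("Create a workspace called Invoices"):
        if step.is_final:
            print(step.final_answer or step.error)
    # AgentMemory persists across run() calls, so follow-ups can reference earlier turns.
    async for step in agent.run("Now add a queue to that workspace"):
        if step.is_final:
            print(step.final_answer or step.error)
    agent.log_token_usage_summary()  # main-agent vs sub-agent breakdown
    agent.reset()                    # clear memory and token counters for a fresh session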