kader 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kader/agent/base.py ADDED
@@ -0,0 +1,927 @@
1
+ """
2
+ Base Agent Implementation.
3
+
4
+ Defines the BaseAgent class which serves as the foundation for creating specific agents
5
+ with tools, memory, and LLM provider integration.
6
+ """
7
+
8
+ from pathlib import Path
9
+ from typing import AsyncIterator, Callable, Iterator, Optional, Union
10
+
11
+ import yaml
12
+ from tenacity import RetryError, stop_after_attempt, wait_exponential
13
+
14
+ from kader.memory import (
15
+ ConversationManager,
16
+ FileSessionManager,
17
+ SlidingWindowConversationManager,
18
+ )
19
+ from kader.prompts.base import PromptBase
20
+ from kader.providers.base import (
21
+ BaseLLMProvider,
22
+ LLMResponse,
23
+ Message,
24
+ ModelConfig,
25
+ StreamChunk,
26
+ )
27
+ from kader.providers.ollama import OllamaProvider
28
+ from kader.tools import BaseTool, ToolRegistry
29
+
30
+ from .logger import agent_logger
31
+
32
+
33
+ class BaseAgent:
34
+ """
35
+ Base class for Agents.
36
+
37
+ Combines tools, memory, and an LLM provider to perform tasks.
38
+ Supports synchronous and asynchronous invocation and streaming.
39
+ Includes built-in retry logic using tenacity.
40
+ Supports session persistence via FileSessionManager.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ name: str,
46
+ system_prompt: Union[str, PromptBase],
47
+ tools: Union[list[BaseTool], ToolRegistry, None] = None,
48
+ provider: Optional[BaseLLMProvider] = None,
49
+ memory: Optional[ConversationManager] = None,
50
+ retry_attempts: int = 3,
51
+ model_name: str = "qwen3-coder:480b-cloud",
52
+ session_id: Optional[str] = None,
53
+ use_persistence: bool = False,
54
+ interrupt_before_tool: bool = True,
55
+ tool_confirmation_callback: Optional[Callable[[str], tuple[bool, Optional[str]]]] = None,
56
+ ) -> None:
57
+ """
58
+ Initialize the Base Agent.
59
+
60
+ Args:
61
+ name: Name of the agent.
62
+ system_prompt: The system prompt definition.
63
+ tools: List of tools or a ToolRegistry.
64
+ provider: LLM provider instance. If None, uses OllamaProvider.
65
+ memory: Conversation/Memory manager. If None, uses SlidingWindowConversationManager.
66
+ retry_attempts: Number of retry attempts for LLM calls (default: 3).
67
+ model_name: Default model name if creating a default Ollama provider.
68
+ session_id: Optional session ID to load/resume.
69
+ use_persistence: If True, enables session persistence (auto-enabled if session_id provided).
70
+ interrupt_before_tool: If True, pauses and asks for user confirmation before executing tools.
71
+ tool_confirmation_callback: Optional callback function for tool confirmation.
72
+ Signature: (message: str) -> tuple[bool, Optional[str]]
73
+ Returns (should_execute, user_elaboration_if_declined).
74
+ """
75
+ self.name = name
76
+ self.system_prompt = system_prompt
77
+ self.retry_attempts = retry_attempts
78
+ self.interrupt_before_tool = interrupt_before_tool
79
+ self.tool_confirmation_callback = tool_confirmation_callback
80
+
81
+ # Persistence Configuration
82
+ self.session_id = session_id
83
+ self.use_persistence = use_persistence or (session_id is not None)
84
+ self.session_manager = FileSessionManager() if self.use_persistence else None
85
+
86
+ # Initialize Logger if agent uses persistence (logs only if there's a session)
87
+ self.logger_id = None
88
+ if self.use_persistence:
89
+ # Create the logger now only if a session_id is already known;
90
+ # otherwise _load_session will create a session and set it up there.
91
+ session_id_for_logger = self.session_id
94
+ if session_id_for_logger:
95
+ self.logger_id = agent_logger.setup_logger(
96
+ self.name, session_id_for_logger
97
+ )
98
+
99
+ # Initialize Provider
100
+ if provider:
101
+ self.provider = provider
102
+ else:
103
+ self.provider = OllamaProvider(model=model_name)
104
+
105
+ # Initialize Memory
106
+ if memory:
107
+ self.memory = memory
108
+ else:
109
+ self.memory = SlidingWindowConversationManager()
110
+
111
+ # Initialize Tools
112
+ self._tool_registry = ToolRegistry()
113
+ if tools:
114
+ if isinstance(tools, ToolRegistry):
115
+ self._tool_registry = tools
116
+ elif isinstance(tools, list):
117
+ for tool in tools:
118
+ self._tool_registry.register(tool)
119
+
120
+ if self.use_persistence:
121
+ self._load_session()
122
+
123
+ # Propagate session to tools
124
+ self._propagate_session_to_tools()
125
+
126
+ # Update config with tools if provider supports it
127
+ self._update_provider_tools()
128
+
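As an illustration of the constructor above, the following sketch (not part of the released module) builds an agent with the default Ollama provider; the model name and prompt are placeholders and assume a reachable Ollama server.

from kader.agent.base import BaseAgent

# Minimal construction sketch; interrupt_before_tool is disabled so invoke()
# runs without a console confirmation prompt.
agent = BaseAgent(
    name="demo",
    system_prompt="You are a concise assistant.",
    model_name="qwen3-coder:480b-cloud",
    retry_attempts=3,
    interrupt_before_tool=False,
)

response = agent.invoke("Summarise what you can do in one sentence.")
print(response.content)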
129
+ def _load_session(self) -> None:
130
+ """Load conversation history from session storage."""
131
+ if not self.session_manager:
132
+ return
133
+
134
+ if not self.session_id:
135
+ session = self.session_manager.create_session(self.name)
136
+ self.session_id = session.session_id
137
+
138
+ # Initialize logger if we now have a session_id and logging hasn't been set up yet
139
+ if self.use_persistence and not self.logger_id and self.session_id:
140
+ self.logger_id = agent_logger.setup_logger(self.name, self.session_id)
141
+
142
+ # Propagate session to tools
143
+ self._propagate_session_to_tools()
144
+
145
+ # Load conversation history
146
+ try:
147
+ # No explicit existence check: load_conversation returns an empty list
148
+ # for a missing session, and any other failure is caught below.
149
+ # FileSessionManager.load_conversation returns list[dict].
150
+ history = self.session_manager.load_conversation(self.session_id)
151
+ if history:
152
+ # Add loaded messages to memory
153
+ # ConversationManager supports adding dicts directly
154
+ self.memory.add_messages(history)
155
+ except Exception:
156
+ # If the session doesn't exist or loading fails, start fresh:
157
+ # proceed silently with empty memory rather than raising.
158
+ pass
159
+
160
+ def _propagate_session_to_tools(self) -> None:
161
+ """Propagate current session ID to all registered tools."""
162
+ if not self.session_id:
163
+ return
164
+
165
+ for tool in self._tool_registry.tools:
166
+ if hasattr(tool, "set_session_id"):
167
+ tool.set_session_id(self.session_id)
168
+
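Session propagation relies on duck typing: any registered tool exposing set_session_id receives the active session ID. A hypothetical sketch of such a tool-side hook (SessionAwareMixin is not part of kader):

from typing import Optional

class SessionAwareMixin:
    """Hypothetical helper a tool could use to receive the agent's session ID."""

    def __init__(self) -> None:
        self.session_id: Optional[str] = None

    def set_session_id(self, session_id: str) -> None:
        # Called by BaseAgent._propagate_session_to_tools()
        self.session_id = session_id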
169
+ def _save_session(self) -> None:
170
+ """Save current conversation history to session storage."""
171
+ if not self.session_manager or not self.session_id:
172
+ return
173
+
174
+ try:
175
+ # Get all messages from memory
176
+ # Convert ConversationMessage to dict (using .message property)
177
+ messages = [msg.message for msg in self.memory.get_messages()]
178
+ self.session_manager.save_conversation(self.session_id, messages)
179
+ except Exception:
180
+ # A failed save should not crash the main flow, so the error is swallowed.
181
+ pass
182
+
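A small persistence sketch (illustrative, not from the module): passing a session_id implies use_persistence=True, so history is loaded on construction and saved after each turn. The session_id value here is a placeholder.

from kader.agent.base import BaseAgent

agent = BaseAgent(
    name="support-bot",
    system_prompt="You help users debug kader agents.",
    session_id="existing-session-id",  # placeholder; omit to create a new session
)
agent.invoke("Where did we leave off?")  # history is saved back to the session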
183
+ @property
184
+ def tools_map(self) -> dict[str, BaseTool]:
185
+ """
186
+ Get a dictionary mapping tool names to tool instances.
187
+
188
+ Returns:
189
+ Dictionary of {tool_name: tool_instance}
190
+ """
191
+ # ToolRegistry exposes a .tools property that returns a list of tools,
192
+ # so build the name -> instance mapping from it.
193
+ return {tool.name: tool for tool in self._tool_registry.tools}
194
+
195
+ def _update_provider_tools(self) -> None:
196
+ """Update the provider's default config with registered tools."""
197
+ if not self._tool_registry.tools:
198
+ return
199
+
200
+ # We need to update the default config of the provider to include these tools
201
+ # Since we can't easily modify the internal default_config of the provider cleanly
202
+ # from here without accessing protected members, strict encapsulation might prevent this.
203
+ # However, for this implementation, we will pass tools during invoke if they exist.
204
+ pass
205
+
206
+ def _get_run_config(self, config: Optional[ModelConfig] = None) -> ModelConfig:
207
+ """Prepare execution config with tools."""
208
+ base_config = config or ModelConfig()
209
+
210
+ # If tools are available and not explicitly disabled or overridden
211
+ if self._tool_registry.tools and not base_config.tools:
212
+ # Detect provider type to format tools correctly
213
+ # Defaulting to 'openai' format as it's the de-facto standard
214
+ provider_type = "openai"
215
+ if isinstance(self.provider, OllamaProvider):
216
+ provider_type = "ollama"
217
+
218
+ base_config = ModelConfig(
219
+ temperature=base_config.temperature,
220
+ max_tokens=base_config.max_tokens,
221
+ top_p=base_config.top_p,
222
+ top_k=base_config.top_k,
223
+ frequency_penalty=base_config.frequency_penalty,
224
+ stop_sequences=base_config.stop_sequences,
225
+ stream=base_config.stream,
226
+ tools=self._tool_registry.to_provider_format(provider_type),
227
+ tool_choice=base_config.tool_choice,
228
+ extra=base_config.extra,
229
+ )
230
+
231
+ return base_config
232
+
233
+ def _prepare_messages(
234
+ self, messages: Union[str, list[Message], list[dict]]
235
+ ) -> list[Message]:
236
+ """Prepare messages adding system prompt and history."""
237
+ # Normalize input to list of Message objects
238
+ input_msgs: list[Message] = []
239
+ if isinstance(messages, str):
240
+ input_msgs = [Message.user(messages)]
241
+ elif isinstance(messages, list):
242
+ # Convert any dicts to Message objects; Message instances pass through.
248
+ input_msgs = [
249
+ Message(**msg) if isinstance(msg, dict) else msg for msg in messages
250
+ ]
251
+ else:
252
+ input_msgs = messages # Assuming list[Message]
253
+
254
+ # Add to memory
255
+ for msg in input_msgs:
256
+ self.memory.add_message(msg)
257
+
258
+ # Retrieve context (system prompt + history)
259
+ # 1. Start with System Prompt
260
+ if isinstance(self.system_prompt, PromptBase):
261
+ sys_prompt_content = self.system_prompt.resolve_prompt()
262
+ else:
263
+ sys_prompt_content = str(self.system_prompt)
264
+
265
+ final_messages = [Message.system(sys_prompt_content)]
266
+
267
+ # 2. Get history from memory (windowed)
268
+ # memory.apply_window() returns list[dict], need to convert back to Message
269
+ history_dicts = self.memory.apply_window()
270
+
271
+ # The new input messages were already added to memory above, so
272
+ # apply_window() returns them (when relevant) along with prior history.
273
+ # Rely on the memory manager's window here rather than appending the
274
+ # input messages a second time.
275
+
276
+ for msg_dict in history_dicts:
277
+ # Basic conversion from the stored dict back to a Message object;
279
+ # only the fields persisted by the conversation store are restored.
279
+ msg = Message(
280
+ role=msg_dict.get("role"),
281
+ content=msg_dict.get("content"),
282
+ name=msg_dict.get("name"),
283
+ tool_call_id=msg_dict.get("tool_call_id"),
284
+ tool_calls=msg_dict.get("tool_calls"),
285
+ )
286
+ final_messages.append(msg)
287
+
288
+ return final_messages
289
+
290
+ def _format_tool_call_for_display(self, tool_call_dict: dict) -> str:
291
+ """
292
+ Format a tool call for display to the user.
293
+
294
+ Args:
295
+ tool_call_dict: The tool call dictionary from LLM response.
296
+
297
+ Returns:
298
+ The tool's interruption message.
299
+ """
300
+ import json
301
+
302
+ fn_info = tool_call_dict.get("function", {})
303
+ if not fn_info and "name" in tool_call_dict:
304
+ fn_info = tool_call_dict
305
+
306
+ tool_name = fn_info.get("name", "unknown")
307
+ arguments = fn_info.get("arguments", {})
308
+
309
+ # Parse arguments if string
310
+ if isinstance(arguments, str):
311
+ try:
312
+ arguments = json.loads(arguments)
313
+ except json.JSONDecodeError:
314
+ pass
315
+
316
+ # Get the tool's interruption message if available
317
+ tool = self._tool_registry.get(tool_name)
318
+ if tool and isinstance(arguments, dict):
319
+ return tool.get_interruption_message(**arguments)
320
+
321
+ # Fallback for unknown tools
322
+ return f"execute {tool_name}"
323
+
324
+ def _confirm_tool_execution(
325
+ self, tool_call_dict: dict, llm_content: Optional[str] = None
326
+ ) -> tuple[bool, Optional[str]]:
327
+ """
328
+ Ask user for confirmation before executing a tool.
329
+
330
+ Args:
331
+ tool_call_dict: The tool call dictionary from LLM response.
332
+
333
+ Returns:
334
+ Tuple of (should_execute: bool, user_input: Optional[str]).
335
+ If should_execute is False, user_input contains additional context.
336
+ """
337
+ display_str = self._format_tool_call_for_display(tool_call_dict)
338
+
339
+ if llm_content and len(llm_content) > 0:
340
+ display_str = f"{llm_content}\n\n{display_str}"
341
+
342
+ # Use callback if provided (e.g., for GUI/TUI)
343
+ if self.tool_confirmation_callback:
344
+ return self.tool_confirmation_callback(display_str)
345
+
346
+ # Default: use console input
347
+ print(display_str)
348
+
349
+ while True:
350
+ user_input = input("\nExecute this tool? (yes/no): ").strip().lower()
351
+
352
+ if user_input in ("yes", "y"):
353
+ return True, None
354
+ elif user_input in ("no", "n"):
355
+ elaboration = input(
356
+ "Please provide more context or instructions: "
357
+ ).strip()
358
+ return False, elaboration if elaboration else None
359
+ else:
360
+ print("Please enter 'yes' or 'no'.")
361
+
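For non-console front-ends, the same contract can be satisfied by a callback. A hypothetical policy-based callback sketch (the prefix check is only illustrative, since the displayed message depends on each tool's get_interruption_message):

from typing import Optional

ALLOWED_PREFIXES = ("read", "list")  # hypothetical read-only policy

def policy_confirmation(message: str) -> tuple[bool, Optional[str]]:
    # Must return (should_execute, user_elaboration_if_declined).
    if message.lower().startswith(ALLOWED_PREFIXES):
        return True, None
    return False, "Only read-only tools are allowed in this session."

# Passed at construction time:
# BaseAgent(..., tool_confirmation_callback=policy_confirmation)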
362
+ async def _aconfirm_tool_execution(
363
+ self, tool_call_dict: dict, llm_content: Optional[str] = None
364
+ ) -> tuple[bool, Optional[str]]:
365
+ """
366
+ Async version - Ask user for confirmation before executing a tool.
367
+
368
+ Note: The synchronous confirmation flow (which may call input()) is run
369
+ in a worker thread via asyncio.to_thread so the event loop is not blocked.
370
+
371
+ Args:
372
+ tool_call_dict: The tool call dictionary from LLM response.
373
+
374
+ Returns:
375
+ Tuple of (should_execute: bool, user_input: Optional[str]).
376
+ """
377
+ # Run the synchronous confirmation flow in a worker thread so that
378
+ # waiting for user input does not block the event loop.
379
+ import asyncio
380
+
381
+ return await asyncio.to_thread(
382
+ self._confirm_tool_execution, tool_call_dict, llm_content
383
+ )
384
+
385
+ def _process_tool_calls(
386
+ self, response: LLMResponse
387
+ ) -> Union[list[Message], tuple[bool, str]]:
388
+ """
389
+ Execute tool calls from response and return tool messages.
390
+
391
+ Args:
392
+ response: The LLM response containing tool calls.
393
+
394
+ Returns:
395
+ List of Message objects representing tool results, or
396
+ Tuple of (False, user_input) if user declined tool execution.
397
+ """
398
+ tool_messages = []
399
+ if response.has_tool_calls:
400
+ for tool_call_dict in response.tool_calls:
401
+ # Check for interrupt before tool execution
402
+ if self.interrupt_before_tool:
403
+ should_execute, user_input = self._confirm_tool_execution(
404
+ tool_call_dict, response.content
405
+ )
406
+ if not should_execute:
407
+ # Return the user's elaboration to be processed
408
+ return (False, user_input)
409
+
410
+ # Need to convert dict to ToolCall object or handle manually
411
+ # ToolRegistry.run takes ToolCall
412
+ from kader.tools.base import ToolCall
413
+
414
+ # Create ToolCall object
415
+ # Some providers might differ in specific dict keys, relying on normalization
416
+ try:
417
+ tool_call = ToolCall(
418
+ id=tool_call_dict.get("id", ""),
419
+ name=tool_call_dict.get("function", {}).get("name", ""),
420
+ arguments=tool_call_dict.get("function", {}).get(
421
+ "arguments", {}
422
+ ),
423
+ raw_arguments=str(
424
+ tool_call_dict.get("function", {}).get("arguments", {})
425
+ ),
426
+ )
427
+ except Exception:
428
+ # Fallback or simplified parsing if structure differs
429
+ tool_call = ToolCall(
430
+ id=tool_call_dict.get("id", ""),
431
+ name=tool_call_dict.get("function", {}).get("name", ""),
432
+ arguments={}, # Error case
433
+ )
434
+
435
+ # Execute tool
436
+ tool_result = self._tool_registry.run(tool_call)
437
+
438
+ # The result is not added to memory here; the tool messages are
439
+ # returned and the caller decides when to add them.
440
+ tool_msg = Message.tool(
441
+ tool_call_id=tool_result.tool_call_id, content=tool_result.content
442
+ )
443
+ tool_messages.append(tool_msg)
444
+
445
+ return tool_messages
446
+
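For reference, a sketch of the OpenAI/Ollama-style tool-call dict this method expects and the equivalent ToolCall construction; the tool name and arguments are placeholders.

from kader.tools.base import ToolCall

tool_call_dict = {
    "id": "call_123",
    "type": "function",
    "function": {"name": "read_file", "arguments": {"path": "README.md"}},
}

tool_call = ToolCall(
    id=tool_call_dict["id"],
    name=tool_call_dict["function"]["name"],
    arguments=tool_call_dict["function"]["arguments"],
    raw_arguments=str(tool_call_dict["function"]["arguments"]),
)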
447
+ async def _aprocess_tool_calls(
448
+ self, response: LLMResponse
449
+ ) -> Union[list[Message], tuple[bool, str]]:
450
+ """
451
+ Async version of _process_tool_calls.
452
+
453
+ Returns:
454
+ List of Message objects representing tool results, or
455
+ Tuple of (False, user_input) if user declined tool execution.
456
+ """
457
+ tool_messages = []
458
+ if response.has_tool_calls:
459
+ for tool_call_dict in response.tool_calls:
460
+ # Check for interrupt before tool execution
461
+ if self.interrupt_before_tool:
462
+ should_execute, user_input = await self._aconfirm_tool_execution(
463
+ tool_call_dict, response.content
464
+ )
465
+ if not should_execute:
466
+ return (False, user_input)
467
+
468
+ from kader.tools.base import ToolCall
469
+
470
+ # Check structure - Ollama/OpenAI usually: {'id':..., 'type': 'function', 'function': {'name':.., 'arguments':..}}
471
+ fn_info = tool_call_dict.get("function", {})
472
+ if not fn_info and "name" in tool_call_dict:
473
+ # Handle flat structure if any
474
+ fn_info = tool_call_dict
475
+
476
+ tool_call = ToolCall(
477
+ id=tool_call_dict.get("id", "call_default"),
478
+ name=fn_info.get("name", ""),
479
+ arguments=fn_info.get("arguments", {}),
480
+ )
481
+
482
+ # Execute tool async
483
+ tool_result = await self._tool_registry.arun(tool_call)
484
+
485
+ tool_msg = Message.tool(
486
+ tool_call_id=tool_result.tool_call_id, content=tool_result.content
487
+ )
488
+ tool_messages.append(tool_msg)
489
+
490
+ return tool_messages
491
+
492
+ # -------------------------------------------------------------------------
493
+ # Synchronous Methods
494
+ # -------------------------------------------------------------------------
495
+
496
+ def invoke(
497
+ self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
498
+ ) -> LLMResponse:
499
+ """
500
+ Synchronously invoke the agent.
501
+
502
+ Handles message preparation, LLM invocation with retries, and tool execution loop.
503
+ """
504
+ # Retry handling.
505
+ # A tenacity decorator would fix the retry settings at definition time,
506
+ # but the attempt count comes from self.retry_attempts, so the functional
507
+ # Retrying API is used to configure retries per instance.
508
+ from tenacity import Retrying
509
+
510
+ runner = Retrying(
511
+ stop=stop_after_attempt(self.retry_attempts),
512
+ wait=wait_exponential(multiplier=1, min=4, max=10),
513
+ reraise=True,
514
+ )
515
+
516
+ final_response = None
517
+
518
+ # Main Agent Loop (Limit turns to avoid infinite loops)
519
+ max_turns = 10
520
+ current_turn = 0
521
+
522
+ while current_turn < max_turns:
523
+ current_turn += 1
524
+
525
+ # Prepare full context
526
+ full_history = self._prepare_messages(messages if current_turn == 1 else [])
527
+ # Note: _prepare_messages adds input to memory. On subsequent turns (tools),
528
+ # we don't re-add the user input. self.memory already has it + previous turns.
529
+
530
+ # Call LLM with retry
531
+ try:
532
+ response = runner(
533
+ self.provider.invoke, full_history, self._get_run_config(config)
534
+ )
535
+ except RetryError as e:
536
+ # Should not happen with reraise=True, but just in case
537
+ raise e
538
+
539
+ # Add assistant response to memory
540
+ self.memory.add_message(response.to_message())
541
+
542
+ # Log the interaction if logger is active
543
+ if self.logger_id:
544
+ # Extract token usage info if available
545
+ token_usage = None
546
+ if hasattr(response, "usage"):
547
+ token_usage = {
548
+ "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
549
+ "completion_tokens": getattr(
550
+ response.usage, "completion_tokens", 0
551
+ ),
552
+ "total_tokens": getattr(response.usage, "total_tokens", 0),
553
+ }
554
+
555
+ # Log the LLM response
556
+ agent_logger.log_llm_response(self.logger_id, str(response.content))
557
+
558
+ # Log token usage and calculate cost
559
+ if token_usage:
560
+ agent_logger.log_token_usage(
561
+ self.logger_id,
562
+ token_usage["prompt_tokens"],
563
+ token_usage["completion_tokens"],
564
+ token_usage["total_tokens"],
565
+ )
566
+
567
+ # Estimate the cost of this call from the token usage
568
+ estimated_cost = self.provider.estimate_cost(token_usage)
569
+
570
+ # Calculate and log cost
571
+ agent_logger.calculate_cost(
572
+ self.logger_id,
573
+ estimated_cost.total_cost,
574
+ )
575
+
576
+ # Save session update
577
+ if self.use_persistence:
578
+ self._save_session()
579
+
580
+ # Check for tool calls
581
+ if response.has_tool_calls:
582
+ tool_result = self._process_tool_calls(response)
583
+
584
+ # Check if user declined tool execution
585
+ if isinstance(tool_result, tuple) and tool_result[0] is False:
586
+ # User declined - add their input as a new message and continue
587
+ user_elaboration = tool_result[1]
588
+ if user_elaboration:
589
+ self.memory.add_message(Message.user(user_elaboration))
590
+ else:
591
+ # User provided no elaboration, return current response
592
+ final_response = response
593
+ break
594
+ continue
595
+
596
+ tool_msgs = tool_result
597
+
598
+ # Add tool outputs to memory
599
+ for tm in tool_msgs:
600
+ self.memory.add_message(tm)
601
+
602
+ # Log tool usage
603
+ if self.logger_id:
604
+ # Extract tool name and arguments
605
+ tool_name = "unknown"
606
+ arguments = {}
607
+ if hasattr(tm, "tool_call_id"):
608
+ # This is a tool message; recover the tool name and arguments from
609
+ # the original response (the first named tool call is logged).
610
+ for tool_call in response.tool_calls:
611
+ fn_info = tool_call.get("function", {})
612
+ if fn_info.get("name"):
613
+ tool_name = fn_info.get("name", "unknown")
614
+ arguments = fn_info.get("arguments", {})
615
+ agent_logger.log_tool_usage(
616
+ self.logger_id, tool_name, arguments
617
+ )
618
+ break
619
+
620
+ # Save session update after tool results
621
+ if self.use_persistence:
622
+ self._save_session()
623
+
624
+ # Loop continues to feed tool outputs back to LLM
625
+ continue
626
+ else:
627
+ # No tools, final response
628
+ final_response = response
629
+ break
630
+
631
+ return final_response
632
+
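The retry pattern used above, shown standalone as a sketch: tenacity's functional Retrying API lets the attempt count come from instance state instead of being fixed by a decorator. flaky_call is a placeholder.

from tenacity import Retrying, stop_after_attempt, wait_exponential

def flaky_call() -> str:
    # stands in for self.provider.invoke(...)
    return "ok"

runner = Retrying(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,  # re-raise the last exception instead of wrapping it in RetryError
)
result = runner(flaky_call)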
633
+ def stream(
634
+ self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
635
+ ) -> Iterator[StreamChunk]:
636
+ """
637
+ Synchronously stream the agent response.
638
+
639
+ Note: this base implementation does not run the tool-execution loop while
640
+ streaming; chunks (including any tool calls) are yielded to the caller as-is.
641
+ """
642
+ # Tool execution does not compose cleanly with streaming, so the provider
643
+ # stream is passed through rather than buffered for tool handling.
644
+
645
+ # Steps:
646
+ # 1. Prepare messages
647
+ full_history = self._prepare_messages(messages)
648
+
649
+ # 2. Stream from provider
650
+ # We need to handle retries for the *start* of the stream
651
+ from tenacity import Retrying
652
+
653
+ runner = Retrying(
654
+ stop=stop_after_attempt(self.retry_attempts),
655
+ wait=wait_exponential(multiplier=1, min=4, max=10),
656
+ reraise=True,
657
+ )
658
+
659
+ # We can't retry the *iteration* easily if it fails mid-stream without complex logic.
660
+ # We will retry obtaining the iterator.
661
+ stream_iterator = runner(
662
+ self.provider.stream, full_history, self._get_run_config(config)
663
+ )
664
+
665
+ yield from stream_iterator
666
+
667
+ # Update the session at the end if persistence is enabled.
668
+ # Note: streaming complicates memory/persistence because the full message
669
+ # only exists once the chunks are aggregated. This implementation does not
670
+ # aggregate the response into memory; it only yields chunks.
671
+ # The caller of stream() is responsible for re-assembling the response and
672
+ # adding it to memory if it should appear in the history.
673
+ # The input messages, however, were already added to memory by
674
+ # _prepare_messages, so the save below persists only that part.
675
+ # TODO: a more robust stream implementation should aggregate the response
676
+ # chunks and save them as well.
677
+ if self.use_persistence:
678
+ self._save_session()
679
+
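Since stream() does not add the assistant reply to memory, a caller that wants it in history has to re-assemble the chunks, roughly as in this sketch (the agent instance and prompt are placeholders):

chunks: list[str] = []
for chunk in agent.stream("Explain the agent loop in two sentences."):
    print(chunk.content, end="", flush=True)
    chunks.append(chunk.content)

full_reply = "".join(chunks)
# Optionally add the assembled reply back into history, e.g.:
# agent.memory.add_message(Message(role="assistant", content=full_reply))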
680
+ # -------------------------------------------------------------------------
681
+ # Asynchronous Methods
682
+ # -------------------------------------------------------------------------
683
+
684
+ async def ainvoke(
685
+ self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
686
+ ) -> LLMResponse:
687
+ """Asynchronous invocation with retries and tool loop."""
688
+ from tenacity import AsyncRetrying
689
+
690
+ runner = AsyncRetrying(
691
+ stop=stop_after_attempt(self.retry_attempts),
692
+ wait=wait_exponential(multiplier=1, min=4, max=10),
693
+ reraise=True,
694
+ )
695
+
696
+ max_turns = 10
697
+ current_turn = 0
698
+ final_response = None
699
+
700
+ while current_turn < max_turns:
701
+ current_turn += 1
702
+ full_history = self._prepare_messages(messages if current_turn == 1 else [])
703
+
704
+ response = await runner(
705
+ self.provider.ainvoke, full_history, self._get_run_config(config)
706
+ )
707
+
708
+ self.memory.add_message(response.to_message())
709
+
710
+ # Log the interaction if logger is active
711
+ if self.logger_id:
712
+ # Extract token usage info if available
713
+ token_usage = None
714
+ if hasattr(response, "usage"):
715
+ token_usage = {
716
+ "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
717
+ "completion_tokens": getattr(
718
+ response.usage, "completion_tokens", 0
719
+ ),
720
+ "total_tokens": getattr(response.usage, "total_tokens", 0),
721
+ }
722
+
723
+ # Log the LLM response
724
+ agent_logger.log_llm_response(self.logger_id, str(response.content))
725
+
726
+ # Log token usage and calculate cost
727
+ if token_usage:
728
+ agent_logger.log_token_usage(
729
+ self.logger_id,
730
+ token_usage["prompt_tokens"],
731
+ token_usage["completion_tokens"],
732
+ token_usage["total_tokens"],
733
+ )
734
+
735
+ # Estimate the cost of this call from the token usage
736
+ estimated_cost = self.provider.estimate_cost(token_usage)
737
+
738
+ # Calculate and log cost
739
+ agent_logger.calculate_cost(
740
+ self.logger_id,
741
+ estimated_cost.total_cost,
742
+ )
743
+
744
+ # Save session update
745
+ if self.use_persistence:
746
+ self._save_session()
747
+
748
+ if response.has_tool_calls:
749
+ tool_result = await self._aprocess_tool_calls(response)
750
+
751
+ # Check if user declined tool execution
752
+ if isinstance(tool_result, tuple) and tool_result[0] is False:
753
+ # User declined - add their input as a new message and continue
754
+ user_elaboration = tool_result[1]
755
+ if user_elaboration:
756
+ self.memory.add_message(Message.user(user_elaboration))
757
+ else:
758
+ final_response = response
759
+ break
760
+ continue
761
+
762
+ tool_msgs = tool_result
763
+
764
+ for tm in tool_msgs:
765
+ self.memory.add_message(tm)
766
+
767
+ # Log tool usage
768
+ if self.logger_id:
769
+ # Extract tool name and arguments
770
+ tool_name = "unknown"
771
+ arguments = {}
772
+ if hasattr(tm, "tool_call_id"):
773
+ # This is a tool message; recover the tool name and arguments from
774
+ # the original response (the first named tool call is logged).
775
+ for tool_call in response.tool_calls:
776
+ fn_info = tool_call.get("function", {})
777
+ if fn_info.get("name"):
778
+ tool_name = fn_info.get("name", "unknown")
779
+ arguments = fn_info.get("arguments", {})
780
+ agent_logger.log_tool_usage(
781
+ self.logger_id, tool_name, arguments
782
+ )
783
+ break
784
+
785
+ # Save session update
786
+ if self.use_persistence:
787
+ self._save_session()
788
+ continue
789
+ else:
790
+ final_response = response
791
+ break
792
+
793
+ return final_response
794
+
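Driving the async variant from a script, as a sketch (the agent instance and prompt are placeholders):

import asyncio

async def main() -> None:
    response = await agent.ainvoke("Give me a one-line status update.")
    print(response.content)

asyncio.run(main())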
795
+ async def astream(
796
+ self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
797
+ ) -> AsyncIterator[StreamChunk]:
798
+ """Asynchronous streaming with memory aggregation."""
799
+ # Prepare messages
800
+ full_history = self._prepare_messages(messages)
801
+
802
+ # Determine config
803
+ run_config = self._get_run_config(config)
804
+
805
+ # Get stream iterator directly (cannot use tenacity on async generator creation easily)
806
+ stream_iterator = self.provider.astream(full_history, run_config)
807
+
808
+ aggregated_content = ""
809
+ aggregated_tool_calls = []
810
+
811
+ async for chunk in stream_iterator:
812
+ aggregated_content += chunk.content
813
+ if chunk.tool_calls:
814
+ # TODO: robust tool call aggregation if streaming partial JSON
815
+ # For now, assume provider yields complete tool calls in chunks or we just collect them
816
+ aggregated_tool_calls.extend(chunk.tool_calls)
817
+ yield chunk
818
+
819
+ # Create Message and add to memory
820
+ # Note: the message is added even if content and tool calls are both empty.
821
+
822
+ # If we have tool calls, we might need to properly format them
823
+ final_msg = Message(
824
+ role="assistant",
825
+ content=aggregated_content,
826
+ tool_calls=aggregated_tool_calls if aggregated_tool_calls else None,
827
+ )
828
+
829
+ self.memory.add_message(final_msg)
830
+
831
+ if self.use_persistence:
832
+ self._save_session()
833
+
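Unlike stream(), astream() aggregates chunks and adds the reply to memory itself, so consuming it is just iteration; a sketch (the agent instance and prompt are placeholders):

import asyncio

async def stream_demo() -> None:
    async for chunk in agent.astream("Stream a short greeting."):
        print(chunk.content, end="", flush=True)

asyncio.run(stream_demo())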
834
+ # -------------------------------------------------------------------------
835
+ # Serialization Methods
836
+ # -------------------------------------------------------------------------
837
+
838
+ def to_yaml(self, path: Union[str, Path]) -> None:
839
+ """
840
+ Serialize agent configuration to YAML.
841
+
842
+ Args:
843
+ path: File path to save YAML.
844
+ """
845
+ system_prompt_str = (
846
+ self.system_prompt.resolve_prompt()
847
+ if isinstance(self.system_prompt, PromptBase)
848
+ else str(self.system_prompt)
849
+ )
850
+ data = {
851
+ "name": self.name,
852
+ "system_prompt": system_prompt_str,
853
+ "retry_attempts": self.retry_attempts,
854
+ "provider": {
855
+ "model": self.provider.model,
856
+ # Add other provider settings if possible
857
+ },
858
+ "tools": self._tool_registry.names,
859
+ # Memory state could be saved here too, but usually configured separately
860
+ }
861
+
862
+ path_obj = Path(path)
863
+ path_obj.parent.mkdir(parents=True, exist_ok=True)
864
+
865
+ with open(path_obj, "w", encoding="utf-8") as f:
866
+ yaml.dump(data, f, sort_keys=False, default_flow_style=False)
867
+
868
+ @classmethod
869
+ def from_yaml(
870
+ cls, path: Union[str, Path], tool_registry: Optional[ToolRegistry] = None
871
+ ) -> "BaseAgent":
872
+ """
873
+ Load agent from YAML configuration.
874
+
875
+ Args:
876
+ path: Path to YAML file.
877
+ tool_registry: Registry containing *available* tools to re-hydrate the agent.
878
+ The agent's tools will be selected from this registry based on names in YAML.
879
+
880
+ Returns:
881
+ Instantiated BaseAgent.
882
+ """
883
+ path_obj = Path(path)
884
+ with open(path_obj, "r", encoding="utf-8") as f:
885
+ data = yaml.safe_load(f)
886
+
887
+ name = data.get("name", "unnamed_agent")
888
+ system_prompt = data.get("system_prompt", "")
889
+ retry_attempts = data.get("retry_attempts", 3)
890
+ provider_config = data.get("provider", {})
891
+ model_name = provider_config.get("model", "qwen3-coder:480b-cloud")
892
+
893
+ # Handle persistence parameter
894
+ # Honour an explicit "persistence" key from the YAML:
895
+ # True or False is used as given, and when the key is absent
896
+ # use_persistence defaults to True.
897
+ use_persistence = data.get("persistence", True)
898
+
899
+ # Reconstruct tools
900
+ tools = []
901
+ tool_names = data.get("tools", [])
902
+
903
+ # Use provided registry or fallback to default
904
+ registry = tool_registry
905
+ if registry is None:
906
+ # Lazy import to avoid circular dependencies if any
907
+ try:
908
+ from kader.tools import get_default_registry
909
+
910
+ registry = get_default_registry()
911
+ except ImportError:
912
+ pass
913
+
914
+ if tool_names and registry:
915
+ for t_name in tool_names:
916
+ t = registry.get(t_name)
917
+ if t:
918
+ tools.append(t)
919
+
920
+ return cls(
921
+ name=name,
922
+ system_prompt=system_prompt,
923
+ tools=tools,
924
+ retry_attempts=retry_attempts,
925
+ model_name=model_name,
926
+ use_persistence=use_persistence,
927
+ )
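Round-tripping a configuration through YAML, as a sketch (the path is a placeholder; with tool_registry=None, from_yaml falls back to kader.tools.get_default_registry() when available):

agent.to_yaml("configs/demo_agent.yaml")

restored = BaseAgent.from_yaml(
    "configs/demo_agent.yaml",
    tool_registry=None,
)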