kader-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kader/agent/base.py ADDED
@@ -0,0 +1,920 @@
+ """
+ Base Agent Implementation.
+
+ Defines the BaseAgent class which serves as the foundation for creating specific agents
+ with tools, memory, and LLM provider integration.
+ """
+
+ from pathlib import Path
+ from typing import AsyncIterator, Callable, Iterator, Optional, Union
+
+ import yaml
+ from tenacity import RetryError, stop_after_attempt, wait_exponential
+
+ from kader.memory import (
+     ConversationManager,
+     FileSessionManager,
+     SlidingWindowConversationManager,
+ )
+ from kader.prompts.base import PromptBase
+ from kader.providers.base import (
+     BaseLLMProvider,
+     LLMResponse,
+     Message,
+     ModelConfig,
+     StreamChunk,
+ )
+ from kader.providers.ollama import OllamaProvider
+ from kader.tools import BaseTool, ToolRegistry
+
+ from .logger import agent_logger
+
+
+ class BaseAgent:
+     """
+     Base class for Agents.
+
+     Combines tools, memory, and an LLM provider to perform tasks.
+     Supports synchronous and asynchronous invocation and streaming.
+     Includes built-in retry logic using tenacity.
+     Supports session persistence via FileSessionManager.
+     """
+
+     def __init__(
+         self,
+         name: str,
+         system_prompt: Union[str, PromptBase],
+         tools: Union[list[BaseTool], ToolRegistry, None] = None,
+         provider: Optional[BaseLLMProvider] = None,
+         memory: Optional[ConversationManager] = None,
+         retry_attempts: int = 3,
+         model_name: str = "qwen3-coder:480b-cloud",
+         session_id: Optional[str] = None,
+         use_persistence: bool = False,
+         interrupt_before_tool: bool = True,
+         tool_confirmation_callback: Optional[Callable[[str], tuple[bool, Optional[str]]]] = None,
+     ) -> None:
+         """
+         Initialize the Base Agent.
+
+         Args:
+             name: Name of the agent.
+             system_prompt: The system prompt definition.
+             tools: List of tools or a ToolRegistry.
+             provider: LLM provider instance. If None, uses OllamaProvider.
+             memory: Conversation/Memory manager. If None, uses SlidingWindowConversationManager.
+             retry_attempts: Number of retry attempts for LLM calls (default: 3).
+             model_name: Default model name if creating a default Ollama provider.
+             session_id: Optional session ID to load/resume.
+             use_persistence: If True, enables session persistence (auto-enabled if session_id provided).
+             interrupt_before_tool: If True, pauses and asks for user confirmation before executing tools.
+             tool_confirmation_callback: Optional callback function for tool confirmation.
+                 Signature: (message: str) -> tuple[bool, Optional[str]]
+                 Returns (should_execute, user_elaboration_if_declined).
+         """
+         self.name = name
+         self.system_prompt = system_prompt
+         self.retry_attempts = retry_attempts
+         self.interrupt_before_tool = interrupt_before_tool
+         self.tool_confirmation_callback = tool_confirmation_callback
+
+         # Persistence Configuration
+         self.session_id = session_id
+         self.use_persistence = use_persistence or (session_id is not None)
+         self.session_manager = FileSessionManager() if self.use_persistence else None
+
+         # Initialize Logger if agent uses persistence (logs only if there's a session)
+         self.logger_id = None
+         if self.use_persistence:
+             # Only create logger if we have a session_id or if the session manager will create one
+             session_id_for_logger = self.session_id
+             if not session_id_for_logger and self.session_manager:
+                 # If no session_id yet but persistence is enabled, we'll get one during _load_session
+                 pass  # We'll set up the logger in _load_session if needed
+             if session_id_for_logger:
+                 self.logger_id = agent_logger.setup_logger(
+                     self.name, session_id_for_logger
+                 )
+
+         # Initialize Provider
+         if provider:
+             self.provider = provider
+         else:
+             self.provider = OllamaProvider(model=model_name)
+
+         # Initialize Memory
+         if memory:
+             self.memory = memory
+         else:
+             self.memory = SlidingWindowConversationManager()
+
+         # Initialize Tools
+         self._tool_registry = ToolRegistry()
+         if tools:
+             if isinstance(tools, ToolRegistry):
+                 self._tool_registry = tools
+             elif isinstance(tools, list):
+                 for tool in tools:
+                     self._tool_registry.register(tool)
+
+         if self.use_persistence:
+             self._load_session()
+
+         # Propagate session to tools
+         self._propagate_session_to_tools()
+
+         # Update config with tools if provider supports it
+         self._update_provider_tools()
+
+     def _load_session(self) -> None:
+         """Load conversation history from session storage."""
+         if not self.session_manager:
+             return
+
+         if not self.session_id:
+             session = self.session_manager.create_session(self.name)
+             self.session_id = session.session_id
+
+         # Initialize logger if we now have a session_id and logging hasn't been set up yet
+         if self.use_persistence and not self.logger_id and self.session_id:
+             self.logger_id = agent_logger.setup_logger(self.name, self.session_id)
+
+         # Propagate session to tools
+         self._propagate_session_to_tools()
+
+         # Load conversation history
+         try:
+             # We don't check if session exists first because load_conversation
+             # handles missing sessions by returning an empty list, or we catch the
+             # error. FileSessionManager.load_conversation returns list[dict].
+             history = self.session_manager.load_conversation(self.session_id)
+             if history:
+                 # Add loaded messages to memory
+                 # ConversationManager supports adding dicts directly
+                 self.memory.add_messages(history)
+         except Exception:
+             # If the session doesn't exist or an error occurs, we start fresh
+             # and silently proceed with empty memory
+             pass
+
+     def _propagate_session_to_tools(self) -> None:
+         """Propagate current session ID to all registered tools."""
+         if not self.session_id:
+             return
+
+         for tool in self._tool_registry.tools:
+             if hasattr(tool, "set_session_id"):
+                 tool.set_session_id(self.session_id)
+
+     def _save_session(self) -> None:
+         """Save current conversation history to session storage."""
+         if not self.session_manager or not self.session_id:
+             return
+
+         try:
+             # Get all messages from memory
+             # Convert ConversationMessage to dict (using .message property)
+             messages = [msg.message for msg in self.memory.get_messages()]
+             self.session_manager.save_conversation(self.session_id, messages)
+         except Exception:
+             # A failed save should not crash the main flow; fail silently
+             pass
+
+     @property
+     def tools_map(self) -> dict[str, BaseTool]:
+         """
+         Get a dictionary mapping tool names to tool instances.
+
+         Returns:
+             Dictionary of {tool_name: tool_instance}
+         """
+         # ToolRegistry exposes a .tools property (a list), so iterate over it
+         return {tool.name: tool for tool in self._tool_registry.tools}
+
+     def _update_provider_tools(self) -> None:
+         """Update the provider's default config with registered tools."""
+         if not self._tool_registry.tools:
+             return
+
+         # The provider's internal default_config cannot be modified cleanly without
+         # touching protected members, so tools are instead attached per call in
+         # _get_run_config().
+         pass
+
+     def _get_run_config(self, config: Optional[ModelConfig] = None) -> ModelConfig:
+         """Prepare execution config with tools."""
+         base_config = config or ModelConfig()
+
+         # If tools are available and not explicitly disabled or overridden
+         if self._tool_registry.tools and not base_config.tools:
+             # Detect provider type to format tools correctly
+             # Defaulting to 'openai' format as it's the de-facto standard
+             provider_type = "openai"
+             if isinstance(self.provider, OllamaProvider):
+                 provider_type = "ollama"
+
+             base_config = ModelConfig(
+                 temperature=base_config.temperature,
+                 max_tokens=base_config.max_tokens,
+                 top_p=base_config.top_p,
+                 top_k=base_config.top_k,
+                 frequency_penalty=base_config.frequency_penalty,
+                 stop_sequences=base_config.stop_sequences,
+                 stream=base_config.stream,
+                 tools=self._tool_registry.to_provider_format(provider_type),
+                 tool_choice=base_config.tool_choice,
+                 extra=base_config.extra,
+             )
+
+         return base_config
+
+     def _prepare_messages(
+         self, messages: Union[str, list[Message], list[dict]]
+     ) -> list[Message]:
+         """Prepare messages, adding the system prompt and conversation history."""
+         # Normalize input to list of Message objects
+         input_msgs: list[Message] = []
+         if isinstance(messages, str):
+             input_msgs = [Message.user(messages)]
+         elif isinstance(messages, list):
+             if not messages:
+                 pass
+             elif isinstance(messages[0], dict):
+                 # Convert any dicts to Message objects
+                 input_msgs = [
+                     Message(**msg) if isinstance(msg, dict) else msg for msg in messages
+                 ]
+             else:
+                 input_msgs = messages  # Assuming list[Message]
+
+         # Add to memory
+         for msg in input_msgs:
+             self.memory.add_message(msg)
+
+         # Retrieve context (system prompt + history)
+         # 1. Start with System Prompt
+         if isinstance(self.system_prompt, PromptBase):
+             sys_prompt_content = self.system_prompt.resolve_prompt()
+         else:
+             sys_prompt_content = str(self.system_prompt)
+
+         final_messages = [Message.system(sys_prompt_content)]
+
+         # 2. Get history from memory (windowed)
+         # memory.apply_window() returns list[dict], so convert back to Message
+         history_dicts = self.memory.apply_window()
+
+         # The new input messages were just added to memory, so the windowed history
+         # returned by the memory manager already includes them; use it as-is.
+
+         for msg_dict in history_dicts:
+             # Basic conversion from dict back to Message
+             msg = Message(
+                 role=msg_dict.get("role"),
+                 content=msg_dict.get("content"),
+                 name=msg_dict.get("name"),
+                 tool_call_id=msg_dict.get("tool_call_id"),
+                 tool_calls=msg_dict.get("tool_calls"),
+             )
+             final_messages.append(msg)
+
+         return final_messages
+
+     def _format_tool_call_for_display(self, tool_call_dict: dict) -> str:
+         """
+         Format a tool call for display to the user.
+
+         Args:
+             tool_call_dict: The tool call dictionary from LLM response.
+
+         Returns:
+             The tool's interruption message.
+         """
+         import json
+
+         fn_info = tool_call_dict.get("function", {})
+         if not fn_info and "name" in tool_call_dict:
+             fn_info = tool_call_dict
+
+         tool_name = fn_info.get("name", "unknown")
+         arguments = fn_info.get("arguments", {})
+
+         # Parse arguments if string
+         if isinstance(arguments, str):
+             try:
+                 arguments = json.loads(arguments)
+             except json.JSONDecodeError:
+                 pass
+
+         # Get the tool's interruption message if available
+         tool = self._tool_registry.get(tool_name)
+         if tool and isinstance(arguments, dict):
+             return tool.get_interruption_message(**arguments)
+
+         # Fallback for unknown tools
+         return f"execute {tool_name}"
+
+     def _confirm_tool_execution(
+         self, tool_call_dict: dict
+     ) -> tuple[bool, Optional[str]]:
+         """
+         Ask user for confirmation before executing a tool.
+
+         Args:
+             tool_call_dict: The tool call dictionary from LLM response.
+
+         Returns:
+             Tuple of (should_execute: bool, user_input: Optional[str]).
+             If should_execute is False, user_input contains additional context.
+         """
+         display_str = self._format_tool_call_for_display(tool_call_dict)
+
+         # Use callback if provided (e.g., for GUI/TUI)
+         if self.tool_confirmation_callback:
+             return self.tool_confirmation_callback(display_str)
+
+         # Default: use console input
+         print(display_str)
+
+         while True:
+             user_input = input("\nExecute this tool? (yes/no): ").strip().lower()
+
+             if user_input in ("yes", "y"):
+                 return True, None
+             elif user_input in ("no", "n"):
+                 elaboration = input(
+                     "Please provide more context or instructions: "
+                 ).strip()
+                 return False, elaboration if elaboration else None
+             else:
+                 print("Please enter 'yes' or 'no'.")
+
+     async def _aconfirm_tool_execution(
+         self, tool_call_dict: dict
+     ) -> tuple[bool, Optional[str]]:
+         """
+         Async version: ask the user for confirmation before executing a tool.
+
+         Note: The synchronous console prompt is run off the event loop via
+         asyncio.to_thread. For richer async input handling, consider aioconsole.
+
+         Args:
+             tool_call_dict: The tool call dictionary from LLM response.
+
+         Returns:
+             Tuple of (should_execute: bool, user_input: Optional[str]).
+         """
+         # Run the blocking console prompt in a worker thread
+         import asyncio
+
+         return await asyncio.to_thread(self._confirm_tool_execution, tool_call_dict)
+
+     def _process_tool_calls(
+         self, response: LLMResponse
+     ) -> Union[list[Message], tuple[bool, str]]:
+         """
+         Execute tool calls from response and return tool messages.
+
+         Args:
+             response: The LLM response containing tool calls.
+
+         Returns:
+             List of Message objects representing tool results, or
+             Tuple of (False, user_input) if user declined tool execution.
+         """
+         tool_messages = []
+         if response.has_tool_calls:
+             for tool_call_dict in response.tool_calls:
+                 # Check for interrupt before tool execution
+                 if self.interrupt_before_tool:
+                     should_execute, user_input = self._confirm_tool_execution(
+                         tool_call_dict
+                     )
+                     if not should_execute:
+                         # Return the user's elaboration to be processed
+                         return (False, user_input)
+
+                 # ToolRegistry.run expects a ToolCall object, so convert the raw dict
+                 from kader.tools.base import ToolCall
+
+                 # Create ToolCall object; provider payloads may differ slightly in
+                 # their dict keys, so fall back to a simplified parse on errors
+                 try:
+                     tool_call = ToolCall(
+                         id=tool_call_dict.get("id", ""),
+                         name=tool_call_dict.get("function", {}).get("name", ""),
+                         arguments=tool_call_dict.get("function", {}).get(
+                             "arguments", {}
+                         ),
+                         raw_arguments=str(
+                             tool_call_dict.get("function", {}).get("arguments", {})
+                         ),
+                     )
+                 except Exception:
+                     # Fallback or simplified parsing if structure differs
+                     tool_call = ToolCall(
+                         id=tool_call_dict.get("id", ""),
+                         name=tool_call_dict.get("function", {}).get("name", ""),
+                         arguments={},  # Error case
+                     )
+
+                 # Execute tool
+                 tool_result = self._tool_registry.run(tool_call)
+
+                 # Build a tool message; the caller is responsible for adding it to memory
+                 tool_msg = Message.tool(
+                     tool_call_id=tool_result.tool_call_id, content=tool_result.content
+                 )
+                 tool_messages.append(tool_msg)
+
+         return tool_messages
+
+     async def _aprocess_tool_calls(
+         self, response: LLMResponse
+     ) -> Union[list[Message], tuple[bool, str]]:
+         """
+         Async version of _process_tool_calls.
+
+         Returns:
+             List of Message objects representing tool results, or
+             Tuple of (False, user_input) if user declined tool execution.
+         """
+         tool_messages = []
+         if response.has_tool_calls:
+             for tool_call_dict in response.tool_calls:
+                 # Check for interrupt before tool execution
+                 if self.interrupt_before_tool:
+                     should_execute, user_input = await self._aconfirm_tool_execution(
+                         tool_call_dict
+                     )
+                     if not should_execute:
+                         return (False, user_input)
+
+                 from kader.tools.base import ToolCall
+
+                 # Check structure - Ollama/OpenAI usually: {'id':..., 'type': 'function', 'function': {'name':.., 'arguments':..}}
+                 fn_info = tool_call_dict.get("function", {})
+                 if not fn_info and "name" in tool_call_dict:
+                     # Handle flat structure if any
+                     fn_info = tool_call_dict
+
+                 tool_call = ToolCall(
+                     id=tool_call_dict.get("id", "call_default"),
+                     name=fn_info.get("name", ""),
+                     arguments=fn_info.get("arguments", {}),
+                 )
+
+                 # Execute tool async
+                 tool_result = await self._tool_registry.arun(tool_call)
+
+                 tool_msg = Message.tool(
+                     tool_call_id=tool_result.tool_call_id, content=tool_result.content
+                 )
+                 tool_messages.append(tool_msg)
+
+         return tool_messages
+
+     # -------------------------------------------------------------------------
+     # Synchronous Methods
+     # -------------------------------------------------------------------------
+
+     def invoke(
+         self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
+     ) -> LLMResponse:
+         """
+         Synchronously invoke the agent.
+
+         Handles message preparation, LLM invocation with retries, and tool execution loop.
+         """
+         # Use tenacity's functional Retrying API so the number of attempts can be
+         # configured per instance; a decorator would fix it at definition time.
+         from tenacity import Retrying
+
+         runner = Retrying(
+             stop=stop_after_attempt(self.retry_attempts),
+             wait=wait_exponential(multiplier=1, min=4, max=10),
+             reraise=True,
+         )
+
+         final_response = None
+
+         # Main Agent Loop (Limit turns to avoid infinite loops)
+         max_turns = 10
+         current_turn = 0
+
+         while current_turn < max_turns:
+             current_turn += 1
+
+             # Prepare full context
+             full_history = self._prepare_messages(messages if current_turn == 1 else [])
+             # Note: _prepare_messages adds input to memory. On subsequent turns (tools),
+             # we don't re-add the user input. self.memory already has it + previous turns.
+
+             # Call LLM with retry
+             try:
+                 response = runner(
+                     self.provider.invoke, full_history, self._get_run_config(config)
+                 )
+             except RetryError as e:
+                 # Should not happen with reraise=True, but just in case
+                 raise e
+
+             # Add assistant response to memory
+             self.memory.add_message(response.to_message())
+
+             # Log the interaction if logger is active
+             if self.logger_id:
+                 # Extract token usage info if available
+                 token_usage = None
+                 if hasattr(response, "usage"):
+                     token_usage = {
+                         "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                         "completion_tokens": getattr(
+                             response.usage, "completion_tokens", 0
+                         ),
+                         "total_tokens": getattr(response.usage, "total_tokens", 0),
+                     }
+
+                 # Log the LLM response
+                 agent_logger.log_llm_response(self.logger_id, str(response.content))
+
+                 # Log token usage and calculate cost
+                 if token_usage:
+                     agent_logger.log_token_usage(
+                         self.logger_id,
+                         token_usage["prompt_tokens"],
+                         token_usage["completion_tokens"],
+                         token_usage["total_tokens"],
+                     )
+
+                     # Calculate and log cost
+                     agent_logger.calculate_cost(
+                         self.logger_id,
+                         token_usage["prompt_tokens"],
+                         token_usage["completion_tokens"],
+                         getattr(self.provider, "model", ""),
+                     )
+
+             # Save session update
+             if self.use_persistence:
+                 self._save_session()
+
+             # Check for tool calls
+             if response.has_tool_calls:
+                 tool_result = self._process_tool_calls(response)
+
+                 # Check if user declined tool execution
+                 if isinstance(tool_result, tuple) and tool_result[0] is False:
+                     # User declined - add their input as a new message and continue
+                     user_elaboration = tool_result[1]
+                     if user_elaboration:
+                         self.memory.add_message(Message.user(user_elaboration))
+                     else:
+                         # User provided no elaboration, return current response
+                         final_response = response
+                         break
+                     continue
+
+                 tool_msgs = tool_result
+
+                 # Add tool outputs to memory
+                 for tm in tool_msgs:
+                     self.memory.add_message(tm)
+
+                     # Log tool usage
+                     if self.logger_id:
+                         # Extract tool name and arguments
+                         tool_name = "unknown"
+                         arguments = {}
+                         if hasattr(tm, "tool_call_id"):
+                             # This is a tool message, need to find the tool name
+                             # We'll check the original response to find the tool
+                             for tool_call in response.tool_calls:
+                                 fn_info = tool_call.get("function", {})
+                                 if fn_info.get("name"):
+                                     tool_name = fn_info.get("name", "unknown")
+                                     arguments = fn_info.get("arguments", {})
+                                     agent_logger.log_tool_usage(
+                                         self.logger_id, tool_name, arguments
+                                     )
+                                     break
+
+                 # Save session update after tool results
+                 if self.use_persistence:
+                     self._save_session()
+
+                 # Loop continues to feed tool outputs back to LLM
+                 continue
+             else:
+                 # No tools, final response
+                 final_response = response
+                 break
+
+         return final_response
+
+     def stream(
+         self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
+     ) -> Iterator[StreamChunk]:
+         """
+         Synchronously stream the agent response.
+
+         Note: Tool execution typically interrupts the streaming flow; if tools are
+         called, the stream has to be consumed to execute them before a final answer
+         can be streamed.
+         """
+         # Simplified approach for this base implementation:
+         # 1. Prepare messages
+         full_history = self._prepare_messages(messages)
+
+         # 2. Stream from provider
+         # Retries only cover *obtaining* the stream iterator; a failure mid-iteration
+         # is not retried.
+         from tenacity import Retrying
+
+         runner = Retrying(
+             stop=stop_after_attempt(self.retry_attempts),
+             wait=wait_exponential(multiplier=1, min=4, max=10),
+             reraise=True,
+         )
+
+         stream_iterator = runner(
+             self.provider.stream, full_history, self._get_run_config(config)
+         )
+
+         yield from stream_iterator
+
+         # Note: this base implementation does not aggregate the streamed chunks back
+         # into memory; it only yields them. _prepare_messages already stored the input
+         # messages, so only the input side is persisted here. Callers that want the
+         # response in history must re-assemble the chunks and add them to memory.
+         # TODO: A robust stream implementation should aggregate and save the response.
+         if self.use_persistence:
+             self._save_session()
+
+     # -------------------------------------------------------------------------
+     # Asynchronous Methods
+     # -------------------------------------------------------------------------
+
+     async def ainvoke(
+         self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
+     ) -> LLMResponse:
+         """Asynchronous invocation with retries and tool loop."""
+         from tenacity import AsyncRetrying
+
+         runner = AsyncRetrying(
+             stop=stop_after_attempt(self.retry_attempts),
+             wait=wait_exponential(multiplier=1, min=4, max=10),
+             reraise=True,
+         )
+
+         max_turns = 10
+         current_turn = 0
+         final_response = None
+
+         while current_turn < max_turns:
+             current_turn += 1
+             full_history = self._prepare_messages(messages if current_turn == 1 else [])
+
+             response = await runner(
+                 self.provider.ainvoke, full_history, self._get_run_config(config)
+             )
+
+             self.memory.add_message(response.to_message())
+
+             # Log the interaction if logger is active
+             if self.logger_id:
+                 # Extract token usage info if available
+                 token_usage = None
+                 if hasattr(response, "usage"):
+                     token_usage = {
+                         "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                         "completion_tokens": getattr(
+                             response.usage, "completion_tokens", 0
+                         ),
+                         "total_tokens": getattr(response.usage, "total_tokens", 0),
+                     }
+
+                 # Log the LLM response
+                 agent_logger.log_llm_response(self.logger_id, str(response.content))
+
+                 # Log token usage and calculate cost
+                 if token_usage:
+                     agent_logger.log_token_usage(
+                         self.logger_id,
+                         token_usage["prompt_tokens"],
+                         token_usage["completion_tokens"],
+                         token_usage["total_tokens"],
+                     )
+
+                     # Calculate and log cost
+                     agent_logger.calculate_cost(
+                         self.logger_id,
+                         token_usage["prompt_tokens"],
+                         token_usage["completion_tokens"],
+                         getattr(self.provider, "model", ""),
+                     )
+
+             # Save session update
+             if self.use_persistence:
+                 self._save_session()
+
+             if response.has_tool_calls:
+                 tool_result = await self._aprocess_tool_calls(response)
+
+                 # Check if user declined tool execution
+                 if isinstance(tool_result, tuple) and tool_result[0] is False:
+                     # User declined - add their input as a new message and continue
+                     user_elaboration = tool_result[1]
+                     if user_elaboration:
+                         self.memory.add_message(Message.user(user_elaboration))
+                     else:
+                         final_response = response
+                         break
+                     continue
+
+                 tool_msgs = tool_result
+
+                 for tm in tool_msgs:
+                     self.memory.add_message(tm)
+
+                     # Log tool usage
+                     if self.logger_id:
+                         # Extract tool name and arguments
+                         tool_name = "unknown"
+                         arguments = {}
+                         if hasattr(tm, "tool_call_id"):
+                             # This is a tool message, need to find the tool name
+                             # We'll check the original response to find the tool
+                             for tool_call in response.tool_calls:
+                                 fn_info = tool_call.get("function", {})
+                                 if fn_info.get("name"):
+                                     tool_name = fn_info.get("name", "unknown")
+                                     arguments = fn_info.get("arguments", {})
+                                     agent_logger.log_tool_usage(
+                                         self.logger_id, tool_name, arguments
+                                     )
+                                     break
+
+                 # Save session update
+                 if self.use_persistence:
+                     self._save_session()
+                 continue
+             else:
+                 final_response = response
+                 break
+
+         return final_response
+
+     async def astream(
+         self, messages: Union[str, list[Message]], config: Optional[ModelConfig] = None
+     ) -> AsyncIterator[StreamChunk]:
+         """Asynchronous streaming with memory aggregation."""
+         # Prepare messages
+         full_history = self._prepare_messages(messages)
+
+         # Determine config
+         run_config = self._get_run_config(config)
+
+         # Get stream iterator directly (cannot use tenacity on async generator creation easily)
+         stream_iterator = self.provider.astream(full_history, run_config)
+
+         aggregated_content = ""
+         aggregated_tool_calls = []
+
+         async for chunk in stream_iterator:
+             aggregated_content += chunk.content
+             if chunk.tool_calls:
+                 # TODO: robust tool call aggregation if streaming partial JSON
+                 # For now, assume the provider yields complete tool calls per chunk and just collect them
+                 aggregated_tool_calls.extend(chunk.tool_calls)
+             yield chunk
+
+         # Aggregate the streamed response into a single assistant message and add it
+         # to memory, attaching any tool calls that were collected along the way.
+         final_msg = Message(
+             role="assistant",
+             content=aggregated_content,
+             tool_calls=aggregated_tool_calls if aggregated_tool_calls else None,
+         )
+
+         self.memory.add_message(final_msg)
+
+         if self.use_persistence:
+             self._save_session()
+
+     # -------------------------------------------------------------------------
+     # Serialization Methods
+     # -------------------------------------------------------------------------
+
+     def to_yaml(self, path: Union[str, Path]) -> None:
+         """
+         Serialize agent configuration to YAML.
+
+         Args:
+             path: File path to save YAML.
+         """
+         system_prompt_str = (
+             self.system_prompt.resolve_prompt()
+             if isinstance(self.system_prompt, PromptBase)
+             else str(self.system_prompt)
+         )
+         data = {
+             "name": self.name,
+             "system_prompt": system_prompt_str,
+             "retry_attempts": self.retry_attempts,
+             "provider": {
+                 "model": self.provider.model,
+                 # Add other provider settings if possible
+             },
+             "tools": self._tool_registry.names,
+             # Memory state could be saved here too, but usually configured separately
+         }
+
+         path_obj = Path(path)
+         path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(path_obj, "w", encoding="utf-8") as f:
+             yaml.dump(data, f, sort_keys=False, default_flow_style=False)
+
+     @classmethod
+     def from_yaml(
+         cls, path: Union[str, Path], tool_registry: Optional[ToolRegistry] = None
+     ) -> "BaseAgent":
+         """
+         Load agent from YAML configuration.
+
+         Args:
+             path: Path to YAML file.
+             tool_registry: Registry containing *available* tools to re-hydrate the agent.
+                 The agent's tools will be selected from this registry based on names in YAML.
+
+         Returns:
+             Instantiated BaseAgent.
+         """
+         path_obj = Path(path)
+         with open(path_obj, "r", encoding="utf-8") as f:
+             data = yaml.safe_load(f)
+
+         name = data.get("name", "unnamed_agent")
+         system_prompt = data.get("system_prompt", "")
+         retry_attempts = data.get("retry_attempts", 3)
+         provider_config = data.get("provider", {})
+         model_name = provider_config.get("model", "qwen3-coder:480b-cloud")
+
+         # Handle persistence: honour an explicit "persistence" key in the YAML,
+         # defaulting to True when the key is absent.
+         use_persistence = data.get("persistence", True)
+
+         # Reconstruct tools
+         tools = []
+         tool_names = data.get("tools", [])
+
+         # Use provided registry or fallback to default
+         registry = tool_registry
+         if registry is None:
+             # Lazy import to avoid circular dependencies if any
+             try:
+                 from kader.tools import get_default_registry
+
+                 registry = get_default_registry()
+             except ImportError:
+                 pass
+
+         if tool_names and registry:
+             for t_name in tool_names:
+                 t = registry.get(t_name)
+                 if t:
+                     tools.append(t)
+
+         return cls(
+             name=name,
+             system_prompt=system_prompt,
+             tools=tools,
+             retry_attempts=retry_attempts,
+             model_name=model_name,
+             use_persistence=use_persistence,
+         )
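
For orientation, a minimal usage sketch based on the signatures shown in this diff (BaseAgent.__init__, invoke, to_yaml, from_yaml). The import path is assumed from the file location kader/agent/base.py, and the argument values and file paths are illustrative, not part of the released package:

    # Hypothetical example -- not shipped with the package.
    from kader.agent.base import BaseAgent

    agent = BaseAgent(
        name="demo-agent",
        system_prompt="You are a helpful assistant.",
        use_persistence=True,          # creates/loads a session via FileSessionManager
        interrupt_before_tool=False,   # skip the interactive yes/no tool confirmation
    )

    # invoke() runs the prepare -> LLM -> tool loop and returns an LLMResponse
    response = agent.invoke("Summarize what this agent can do in one sentence.")
    print(response.content)

    # Round-trip the agent configuration through YAML
    agent.to_yaml("configs/demo_agent.yaml")
    restored = BaseAgent.from_yaml("configs/demo_agent.yaml")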