kader 0.1.6__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cli/app.py CHANGED
@@ -1,7 +1,9 @@
 """Kader CLI - Modern Vibe Coding CLI with Textual."""
 
 import asyncio
+import atexit
 import threading
+from concurrent.futures import ThreadPoolExecutor
 from importlib.metadata import version as get_version
 from pathlib import Path
 from typing import Optional
@@ -18,14 +20,13 @@ from textual.widgets import (
     Tree,
 )
 
-from kader.agent.agents import ReActAgent
 from kader.memory import (
     FileSessionManager,
     MemoryConfig,
-    SlidingWindowConversationManager,
 )
-from kader.tools import get_default_registry
+from kader.workflows import PlannerExecutorWorkflow
 
+from .llm_factory import LLMProviderFactory
 from .utils import (
     DEFAULT_MODEL,
     HELP_TEXT,
@@ -103,22 +104,83 @@ class KaderApp(App):
         self._model_selector: Optional[ModelSelector] = None
         self._update_info: Optional[str] = None  # Latest version if update available
 
-        self._agent = self._create_agent(self._current_model)
+        # Dedicated thread pool for agent invocation (isolated from default pool)
+        self._agent_executor = ThreadPoolExecutor(
+            max_workers=2, thread_name_prefix="kader_agent"
+        )
+        # Ensure executor is properly shut down on exit
+        atexit.register(self._agent_executor.shutdown, wait=False)
+
+        self._workflow = self._create_workflow(self._current_model)
 
-    def _create_agent(self, model_name: str) -> ReActAgent:
-        """Create a new ReActAgent with the specified model."""
-        registry = get_default_registry()
-        memory = SlidingWindowConversationManager(window_size=10)
-        return ReActAgent(
+    def _create_workflow(self, model_name: str) -> PlannerExecutorWorkflow:
+        """Create a new PlannerExecutorWorkflow with the specified model."""
+        # Create provider using factory (supports provider:model format)
+        provider = LLMProviderFactory.create_provider(model_name)
+
+        return PlannerExecutorWorkflow(
             name="kader_cli",
-            tools=registry,
-            memory=memory,
-            model_name=model_name,
-            use_persistence=True,
+            provider=provider,
+            model_name=model_name,  # Keep for reference
             interrupt_before_tool=True,
             tool_confirmation_callback=self._tool_confirmation_callback,
+            direct_execution_callback=self._direct_execution_callback,
+            tool_execution_result_callback=self._tool_execution_result_callback,
+            use_persistence=True,
+            executor_names=["executor"],
+        )
+
+    def _direct_execution_callback(self, message: str, tool_name: str) -> None:
+        """
+        Callback for direct execution tools - called from agent thread.
+
+        Shows a message in the conversation view without blocking for confirmation.
+        """
+        # Schedule message display on main thread
+        self.call_from_thread(self._show_direct_execution_message, message, tool_name)
+
+    def _show_direct_execution_message(self, message: str, tool_name: str) -> None:
+        """Show a direct execution message in the conversation view."""
+        try:
+            conversation = self.query_one("#conversation-view", ConversationView)
+            # User-friendly message showing the tool is executing
+            friendly_message = f"[>] Executing {tool_name}..."
+            conversation.add_message(friendly_message, "assistant")
+            conversation.scroll_end()
+        except Exception:
+            pass
+
+    def _tool_execution_result_callback(
+        self, tool_name: str, success: bool, result: str
+    ) -> None:
+        """
+        Callback for tool execution results - called from agent thread.
+
+        Updates the conversation view with the execution result.
+        """
+        # Schedule result display on main thread
+        self.call_from_thread(
+            self._show_tool_execution_result, tool_name, success, result
         )
 
+    def _show_tool_execution_result(
+        self, tool_name: str, success: bool, result: str
+    ) -> None:
+        """Show the tool execution result in the conversation view."""
+        try:
+            conversation = self.query_one("#conversation-view", ConversationView)
+            if success:
+                # User-friendly success message
+                friendly_message = f"(+) {tool_name} completed successfully"
+            else:
+                # User-friendly error message with truncated result
+                error_preview = result[:100] + "..." if len(result) > 100 else result
+                friendly_message = f"(-) {tool_name} failed: {error_preview}"
+            conversation.add_message(friendly_message, "assistant")
+            conversation.scroll_end()
+        except Exception:
+            pass
+
     def _tool_confirmation_callback(self, message: str) -> tuple[bool, Optional[str]]:
         """
         Callback for tool confirmation - called from agent thread.
@@ -135,7 +197,10 @@ class KaderApp(App):
 
         # Wait for user response (blocking in agent thread)
         # This is safe because we're in a background thread
-        self._confirmation_event.wait()
+        # Timeout after 5 minutes to prevent indefinite blocking
+        if not self._confirmation_event.wait(timeout=300):
+            # Timeout occurred - decline tool execution gracefully
+            return (False, "Tool confirmation timed out after 5 minutes")
 
         # Return the result
         return self._confirmation_result
@@ -183,7 +248,8 @@ class KaderApp(App):
         if event.confirmed:
             if tool_message:
                 conversation.add_message(tool_message, "assistant")
-            conversation.add_message("(+) Executing tool...", "assistant")
+            # Show executing message - will be updated by result callback
+            conversation.add_message("[>] Executing tool...", "assistant")
             # Restart spinner
             try:
                 spinner = self.query_one(LoadingSpinner)
@@ -207,13 +273,12 @@ class KaderApp(App):
 
     async def _show_model_selector(self, conversation: ConversationView) -> None:
         """Show the model selector widget."""
-        from kader.providers import OllamaProvider
-
         try:
-            models = OllamaProvider.get_supported_models()
+            # Get models from all available providers
+            models = LLMProviderFactory.get_flat_model_list()
             if not models:
                 conversation.add_message(
-                    "## Models (^^)\n\n*No models found. Is Ollama running?*",
+                    "## Models (^^)\n\n*No models found. Check provider configurations.*",
                     "assistant",
                 )
                 return
@@ -249,7 +314,7 @@ class KaderApp(App):
         # Update model and recreate agent
         old_model = self._current_model
         self._current_model = event.model
-        self._agent = self._create_agent(self._current_model)
+        self._workflow = self._create_workflow(self._current_model)
 
         conversation.add_message(
             f"(+) Model changed from `{old_model}` to `{self._current_model}`",
@@ -431,8 +496,8 @@ Please resize your terminal."""
             await self._show_model_selector(conversation)
         elif cmd == "/clear":
             conversation.clear_messages()
-            self._agent.memory.clear()
-            self._agent.provider.reset_tracking()  # Reset usage/cost tracking
+            self._workflow.planner.memory.clear()
+            self._workflow.planner.provider.reset_tracking()  # Reset usage/cost tracking
             self._current_session_id = None
             self.notify("Conversation cleared!", severity="information")
         elif cmd == "/save":
@@ -462,7 +527,7 @@ Please resize your terminal."""
             )
 
     async def _handle_chat(self, message: str) -> None:
-        """Handle regular chat messages with ReActAgent."""
+        """Handle regular chat messages with PlannerExecutorWorkflow."""
         if self._is_processing:
             self.notify("Please wait for the current response...", severity="warning")
             return
@@ -490,20 +555,25 @@ Please resize your terminal."""
         spinner = self.query_one(LoadingSpinner)
 
         try:
-            # Run the agent invoke in a thread
+            # Run the workflow in a dedicated thread pool
             loop = asyncio.get_event_loop()
             response = await loop.run_in_executor(
-                None, lambda: self._agent.invoke(message)
+                self._agent_executor, lambda: self._workflow.run(message)
             )
 
             # Hide spinner and show response (this runs on main thread via await)
             spinner.stop()
-            if response and response.content:
-                conversation.add_message(response.content, "assistant")
+            if response:
+                conversation.add_message(
+                    response,
+                    "assistant",
+                    model_name=self._workflow.planner.provider.model,
+                    usage_cost=self._workflow.planner.provider.total_cost.total_cost,
+                )
 
         except Exception as e:
             spinner.stop()
-            error_msg = f"(-) **Error:** {str(e)}\n\nMake sure Ollama is running and the model `{self._current_model}` is available."
+            error_msg = f"(-) **Error:** {str(e)}\n\nMake sure the provider for `{self._current_model}` is configured and available."
             conversation.add_message(error_msg, "assistant")
             self.notify(f"Error: {e}", severity="error")
 
@@ -516,7 +586,7 @@ Please resize your terminal."""
         """Clear the conversation (Ctrl+L)."""
         conversation = self.query_one("#conversation-view", ConversationView)
         conversation.clear_messages()
-        self._agent.memory.clear()
+        self._workflow.planner.memory.clear()
         self.notify("Conversation cleared!", severity="information")
 
     def action_save_session(self) -> None:
@@ -548,8 +618,10 @@ Please resize your terminal."""
             session = self._session_manager.create_session("kader_cli")
             self._current_session_id = session.session_id
 
-            # Get messages from agent memory and save
-            messages = [msg.message for msg in self._agent.memory.get_messages()]
+            # Get messages from planner memory and save
+            messages = [
+                msg.message for msg in self._workflow.planner.memory.get_messages()
+            ]
             self._session_manager.save_conversation(self._current_session_id, messages)
 
             conversation.add_message(
@@ -580,11 +652,11 @@ Please resize your terminal."""
 
         # Clear current state
         conversation.clear_messages()
-        self._agent.memory.clear()
+        self._workflow.planner.memory.clear()
 
         # Add loaded messages to memory and UI
         for msg in messages:
-            self._agent.memory.add_message(msg)
+            self._workflow.planner.memory.add_message(msg)
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role in ["user", "assistant"] and content:
@@ -633,9 +705,9 @@ Please resize your terminal."""
         """Display LLM usage costs."""
         try:
             # Get cost and usage from the provider
-            cost = self._agent.provider.total_cost
-            usage = self._agent.provider.total_usage
-            model = self._agent.provider.model
+            cost = self._workflow.planner.provider.total_cost
+            usage = self._workflow.planner.provider.total_usage
+            model = self._workflow.planner.provider.model
 
             lines = [
                 "## Usage Costs ($)\n",
cli/app.tcss CHANGED
@@ -132,6 +132,26 @@ ConversationView {
     scrollbar-size: 1 1;
 }
 
+.message-footer {
+    height: auto;
+    margin-top: 0;
+    padding: 0 1;
+    border-top: none;
+}
+
+.footer-left {
+    color: $secondary;
+    text-style: italic;
+    width: 1fr;
+}
+
+.footer-right {
+    color: $success;
+    text-style: bold;
+    text-align: right;
+    width: auto;
+}
+
 /* ===== Welcome Message ===== */
 
 #welcome {
cli/llm_factory.py ADDED
@@ -0,0 +1,165 @@
+"""LLM Provider Factory for Kader CLI.
+
+Factory pattern implementation for creating LLM provider instances
+with automatic provider detection based on model name format.
+"""
+
+from typing import Optional
+
+from kader.providers import GoogleProvider, OllamaProvider
+from kader.providers.base import BaseLLMProvider, ModelConfig
+
+
+class LLMProviderFactory:
+    """
+    Factory for creating LLM provider instances.
+
+    Supports multiple providers with automatic detection based on model name format.
+    Model names can be specified as:
+    - "provider:model" (e.g., "google:gemini-2.5-flash", "ollama:kimi-k2.5:cloud")
+    - "model" (defaults to Ollama for backward compatibility)
+
+    Example:
+        factory = LLMProviderFactory()
+        provider = factory.create_provider("google:gemini-2.5-flash")
+
+        # Or with default provider (Ollama)
+        provider = factory.create_provider("kimi-k2.5:cloud")
+    """
+
+    # Registered provider classes
+    PROVIDERS: dict[str, type[BaseLLMProvider]] = {
+        "ollama": OllamaProvider,
+        "google": GoogleProvider,
+    }
+
+    # Default provider when no prefix is specified
+    DEFAULT_PROVIDER = "ollama"
+
+    @classmethod
+    def parse_model_name(cls, model_string: str) -> tuple[str, str]:
+        """
+        Parse model string to extract provider and model name.
+
+        Args:
+            model_string: Model string in format "provider:model" or just "model"
+
+        Returns:
+            Tuple of (provider_name, model_name)
+        """
+        # Check if the string starts with a known provider prefix
+        for provider_name in cls.PROVIDERS.keys():
+            prefix = f"{provider_name}:"
+            if model_string.lower().startswith(prefix):
+                return provider_name, model_string[len(prefix) :]
+
+        # No known provider prefix found, use default
+        return cls.DEFAULT_PROVIDER, model_string
+
+    @classmethod
+    def create_provider(
+        cls,
+        model_string: str,
+        config: Optional[ModelConfig] = None,
+    ) -> BaseLLMProvider:
+        """
+        Create an LLM provider instance.
+
+        Args:
+            model_string: Model identifier (e.g., "google:gemini-2.5-flash" or "kimi-k2.5:cloud")
+            config: Optional model configuration
+
+        Returns:
+            Configured provider instance
+
+        Raises:
+            ValueError: If provider is not supported
+        """
+        provider_name, model_name = cls.parse_model_name(model_string)
+
+        provider_class = cls.PROVIDERS.get(provider_name)
+        if not provider_class:
+            supported = ", ".join(cls.PROVIDERS.keys())
+            raise ValueError(
+                f"Unknown provider: {provider_name}. Supported: {supported}"
+            )
+
+        return provider_class(model=model_name, default_config=config)
+
+    @classmethod
+    def get_all_models(cls) -> dict[str, list[str]]:
+        """
+        Get all available models from all registered providers.
+
+        Returns:
+            Dictionary mapping provider names to their available models
+            (with provider prefix included in model names)
+        """
+        models: dict[str, list[str]] = {}
+
+        # Get Ollama models
+        try:
+            ollama_models = OllamaProvider.get_supported_models()
+            models["ollama"] = [f"ollama:{m}" for m in ollama_models]
+        except Exception:
+            models["ollama"] = []
+
+        # Get Google models
+        try:
+            google_models = GoogleProvider.get_supported_models()
+            models["google"] = [f"google:{m}" for m in google_models]
+        except Exception:
+            models["google"] = []
+
+        return models
+
+    @classmethod
+    def get_flat_model_list(cls) -> list[str]:
+        """
+        Get a flattened list of all available models with provider prefixes.
+
+        Returns:
+            List of model strings in "provider:model" format
+        """
+        all_models = cls.get_all_models()
+        flat_list: list[str] = []
+        for models in all_models.values():
+            flat_list.extend(models)
+        return flat_list
+
+    @classmethod
+    def is_provider_available(cls, provider_name: str) -> bool:
+        """
+        Check if a provider is available and configured.
+
+        Args:
+            provider_name: Name of the provider to check
+
+        Returns:
+            True if provider is available and has models, False otherwise
+        """
+        provider_name = provider_name.lower()
+        if provider_name not in cls.PROVIDERS:
+            return False
+
+        # Try to get models to verify provider is working
+        try:
+            provider_class = cls.PROVIDERS[provider_name]
+            models = provider_class.get_supported_models()
+            return len(models) > 0
+        except Exception:
+            return False
+
+    @classmethod
+    def get_provider_name(cls, model_string: str) -> str:
+        """
+        Get the provider name for a given model string.
+
+        Args:
+            model_string: Model string in format "provider:model" or just "model"
+
+        Returns:
+            Provider name (e.g., "ollama", "google")
+        """
+        provider_name, _ = cls.parse_model_name(model_string)
+        return provider_name
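
The parsing rule above only strips a known "provider:" prefix, which is why Ollama tags with an internal colon (e.g. "kimi-k2.5:cloud") survive intact and unprefixed names fall back to Ollama. A standalone restatement of that rule for sanity-checking, deliberately kept independent of the kader package so it runs anywhere:

# Re-implementation of LLMProviderFactory.parse_model_name, for illustration only.
PROVIDERS = ("ollama", "google")
DEFAULT_PROVIDER = "ollama"

def parse_model_name(model_string: str) -> tuple[str, str]:
    for provider_name in PROVIDERS:
        prefix = f"{provider_name}:"
        if model_string.lower().startswith(prefix):
            return provider_name, model_string[len(prefix):]
    return DEFAULT_PROVIDER, model_string

assert parse_model_name("google:gemini-2.5-flash") == ("google", "gemini-2.5-flash")
assert parse_model_name("ollama:kimi-k2.5:cloud") == ("ollama", "kimi-k2.5:cloud")
assert parse_model_name("kimi-k2.5:cloud") == ("ollama", "kimi-k2.5:cloud")

One subtlety: the prefix test lowercases the input, but the model substring is taken from the original string, so "Google:Gemini" maps to provider "google" with model "Gemini".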
cli/utils.py CHANGED
@@ -1,9 +1,9 @@
 """Utility constants and helpers for Kader CLI."""
 
-from kader.providers import OllamaProvider
+from .llm_factory import LLMProviderFactory
 
-# Default model
-DEFAULT_MODEL = "qwen3-coder:480b-cloud"
+# Default model (with provider prefix for clarity)
+DEFAULT_MODEL = "ollama:kimi-k2.5:cloud"
 
 HELP_TEXT = """## Kader CLI Commands
 
@@ -40,24 +40,32 @@ HELP_TEXT = """## Kader CLI Commands
 ### Tips:
 - Type any question to chat with the AI
 - Use **Tab** to navigate between panels
+- Model format: `provider:model` (e.g., `google:gemini-2.5-flash`)
 """
 
 
 def get_models_text() -> str:
-    """Get formatted text of available Ollama models."""
+    """Get formatted text of available models from all providers."""
     try:
-        models = OllamaProvider.get_supported_models()
-        if not models:
-            return "## Available Models (^^)\n\n*No models found. Is Ollama running?*"
+        all_models = LLMProviderFactory.get_all_models()
+        flat_list = LLMProviderFactory.get_flat_model_list()
+
+        if not flat_list:
+            return "## Available Models (^^)\n\n*No models found. Check provider configurations.*"
 
         lines = [
             "## Available Models (^^)\n",
-            "| Model | Status |",
-            "|-------|--------|",
+            "| Provider | Model | Status |",
+            "|----------|-------|--------|",
         ]
-        for model in models:
-            lines.append(f"| {model} | (+) Available |")
+        for provider_name, provider_models in all_models.items():
+            for model in provider_models:
+                lines.append(f"| {provider_name.title()} | `{model}` | (+) Available |")
+
         lines.append(f"\n*Currently using: **{DEFAULT_MODEL}***")
+        lines.append(
+            "\n> (!) Tip: Use `provider:model` format (e.g., `google:gemini-2.5-flash`)"
+        )
         return "\n".join(lines)
     except Exception as e:
        return f"## Available Models (^^)\n\n*Error fetching models: {e}*"
@@ -1,23 +1,43 @@
 """Conversation display widget for Kader CLI."""
 
 from textual.app import ComposeResult
-from textual.containers import VerticalScroll
+from textual.containers import Horizontal, VerticalScroll
 from textual.widgets import Markdown, Static
 
 
 class Message(Static):
     """A single message in the conversation."""
 
-    def __init__(self, content: str, role: str = "user") -> None:
+    def __init__(
+        self,
+        content: str,
+        role: str = "user",
+        model_name: str | None = None,
+        usage_cost: float | None = None,
+    ) -> None:
         super().__init__()
         self.content = content
         self.role = role
+        self.model_name = model_name
+        self.usage_cost = usage_cost
         self.add_class(f"message-{role}")
 
     def compose(self) -> ComposeResult:
         prefix = "(**) **You:**" if self.role == "user" else "(^^) **Kader:**"
         yield Markdown(f"{prefix}\n\n{self.content}")
 
+        if self.role == "assistant" and (
+            self.model_name or self.usage_cost is not None
+        ):
+            with Horizontal(classes="message-footer"):
+                model_label = f"[*] {self.model_name}" if self.model_name else ""
+                yield Static(model_label, classes="footer-left")
+
+                usage_label = (
+                    f"($) {self.usage_cost:.6f}" if self.usage_cost is not None else ""
+                )
+                yield Static(usage_label, classes="footer-right")
+
 
 class ConversationView(VerticalScroll):
     """Scrollable conversation history with markdown rendering."""
@@ -41,11 +61,37 @@ class ConversationView(VerticalScroll):
         background: $surface-darken-1;
         border-left: thick $success;
     }
+
+    .message-footer {
+        height: auto;
+        margin-top: 0;
+        padding: 0 1;
+        border-top: none;
+    }
+
+    .footer-left {
+        color: $secondary;
+        text-style: italic;
+        width: 1fr;
+    }
+
+    .footer-right {
+        color: $success;
+        text-style: bold;
+        text-align: right;
+        width: auto;
+    }
     """
 
-    def add_message(self, content: str, role: str = "user") -> None:
+    def add_message(
+        self,
+        content: str,
+        role: str = "user",
+        model_name: str | None = None,
+        usage_cost: float | None = None,
+    ) -> None:
         """Add a message to the conversation."""
-        message = Message(content, role)
+        message = Message(content, role, model_name, usage_cost)
         self.mount(message)
         self.scroll_end(animate=True)
 
kader/__init__.py CHANGED
@@ -8,6 +8,7 @@ creating the .kader directory in the user's home directory.
 from .config import ENV_FILE_PATH, KADER_DIR, initialize_kader_config
 from .providers import *  # noqa: F401, F403
 from .tools import *  # noqa: F401, F403
+from .utils import Checkpointer
 
 # Initialize the configuration when the module is imported
 initialize_kader_config()
@@ -18,5 +19,6 @@ __all__ = [
     "KADER_DIR",
     "ENV_FILE_PATH",
     "initialize_kader_config",
+    "Checkpointer",
     # Export everything from providers and tools
 ]
kader/agent/agents.py CHANGED
@@ -31,6 +31,8 @@ class ReActAgent(BaseAgent):
         use_persistence: bool = False,
         interrupt_before_tool: bool = True,
         tool_confirmation_callback: Optional[callable] = None,
+        direct_execution_callback: Optional[callable] = None,
+        tool_execution_result_callback: Optional[callable] = None,
     ) -> None:
         # Resolve tools for prompt context if necessary
         # The base agent handles tool registration, but for the prompt template
@@ -67,6 +69,8 @@ class ReActAgent(BaseAgent):
             use_persistence=use_persistence,
             interrupt_before_tool=interrupt_before_tool,
             tool_confirmation_callback=tool_confirmation_callback,
+            direct_execution_callback=direct_execution_callback,
+            tool_execution_result_callback=tool_execution_result_callback,
         )
 
 
@@ -90,6 +94,8 @@ class PlanningAgent(BaseAgent):
         use_persistence: bool = False,
         interrupt_before_tool: bool = True,
         tool_confirmation_callback: Optional[callable] = None,
+        direct_execution_callback: Optional[callable] = None,
+        tool_execution_result_callback: Optional[callable] = None,
     ) -> None:
         # Ensure TodoTool is available
         _todo_tool = TodoTool()
@@ -123,4 +129,6 @@ class PlanningAgent(BaseAgent):
             use_persistence=use_persistence,
             interrupt_before_tool=interrupt_before_tool,
             tool_confirmation_callback=tool_confirmation_callback,
+            direct_execution_callback=direct_execution_callback,
+            tool_execution_result_callback=tool_execution_result_callback,
         )
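
Both agents now accept the same two observer hooks and forward them to the base agent's constructor. A sketch of wiring them up from user code; the constructor arguments mirror the pre-1.1.0 call site in cli/app.py above and may not be the minimal required set, so treat this as illustrative rather than canonical API documentation:

from kader.agent.agents import ReActAgent
from kader.tools import get_default_registry

def on_direct_execution(message: str, tool_name: str) -> None:
    # Non-blocking notification that a tool is running without confirmation.
    print(f"[>] Executing {tool_name}...")

def on_tool_result(tool_name: str, success: bool, result: str) -> None:
    # Post-execution report: success flag plus (possibly long) result text.
    print(f"{'(+)' if success else '(-)'} {tool_name}: {result[:100]}")

agent = ReActAgent(
    name="demo",
    tools=get_default_registry(),
    model_name="ollama:kimi-k2.5:cloud",
    interrupt_before_tool=True,
    direct_execution_callback=on_direct_execution,
    tool_execution_result_callback=on_tool_result,
)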