tunacode-cli 0.0.50__py3-none-any.whl → 0.0.53__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic; consult the package's registry listing for more details.

Files changed (87):
  1. tunacode/cli/commands/base.py +2 -2
  2. tunacode/cli/commands/implementations/__init__.py +7 -1
  3. tunacode/cli/commands/implementations/conversation.py +1 -1
  4. tunacode/cli/commands/implementations/debug.py +1 -1
  5. tunacode/cli/commands/implementations/development.py +4 -1
  6. tunacode/cli/commands/implementations/template.py +132 -0
  7. tunacode/cli/commands/registry.py +28 -1
  8. tunacode/cli/commands/template_shortcut.py +93 -0
  9. tunacode/cli/main.py +6 -0
  10. tunacode/cli/repl.py +29 -174
  11. tunacode/cli/repl_components/__init__.py +10 -0
  12. tunacode/cli/repl_components/command_parser.py +34 -0
  13. tunacode/cli/repl_components/error_recovery.py +88 -0
  14. tunacode/cli/repl_components/output_display.py +33 -0
  15. tunacode/cli/repl_components/tool_executor.py +84 -0
  16. tunacode/configuration/defaults.py +2 -2
  17. tunacode/configuration/settings.py +11 -14
  18. tunacode/constants.py +57 -23
  19. tunacode/context.py +0 -14
  20. tunacode/core/agents/agent_components/__init__.py +27 -0
  21. tunacode/core/agents/agent_components/agent_config.py +109 -0
  22. tunacode/core/agents/agent_components/json_tool_parser.py +109 -0
  23. tunacode/core/agents/agent_components/message_handler.py +100 -0
  24. tunacode/core/agents/agent_components/node_processor.py +480 -0
  25. tunacode/core/agents/agent_components/response_state.py +13 -0
  26. tunacode/core/agents/agent_components/result_wrapper.py +50 -0
  27. tunacode/core/agents/agent_components/task_completion.py +28 -0
  28. tunacode/core/agents/agent_components/tool_buffer.py +24 -0
  29. tunacode/core/agents/agent_components/tool_executor.py +49 -0
  30. tunacode/core/agents/main.py +421 -778
  31. tunacode/core/agents/utils.py +42 -2
  32. tunacode/core/background/manager.py +3 -3
  33. tunacode/core/logging/__init__.py +4 -3
  34. tunacode/core/logging/config.py +29 -16
  35. tunacode/core/logging/formatters.py +1 -1
  36. tunacode/core/logging/handlers.py +41 -7
  37. tunacode/core/setup/__init__.py +2 -0
  38. tunacode/core/setup/agent_setup.py +2 -2
  39. tunacode/core/setup/base.py +2 -2
  40. tunacode/core/setup/config_setup.py +10 -6
  41. tunacode/core/setup/git_safety_setup.py +13 -2
  42. tunacode/core/setup/template_setup.py +75 -0
  43. tunacode/core/state.py +13 -2
  44. tunacode/core/token_usage/api_response_parser.py +6 -2
  45. tunacode/core/token_usage/usage_tracker.py +37 -7
  46. tunacode/core/tool_handler.py +24 -1
  47. tunacode/prompts/system.md +289 -4
  48. tunacode/setup.py +2 -0
  49. tunacode/templates/__init__.py +9 -0
  50. tunacode/templates/loader.py +210 -0
  51. tunacode/tools/glob.py +3 -3
  52. tunacode/tools/grep.py +26 -276
  53. tunacode/tools/grep_components/__init__.py +9 -0
  54. tunacode/tools/grep_components/file_filter.py +93 -0
  55. tunacode/tools/grep_components/pattern_matcher.py +152 -0
  56. tunacode/tools/grep_components/result_formatter.py +45 -0
  57. tunacode/tools/grep_components/search_result.py +35 -0
  58. tunacode/tools/todo.py +27 -21
  59. tunacode/types.py +19 -4
  60. tunacode/ui/completers.py +6 -1
  61. tunacode/ui/decorators.py +2 -2
  62. tunacode/ui/keybindings.py +1 -1
  63. tunacode/ui/panels.py +13 -5
  64. tunacode/ui/prompt_manager.py +1 -1
  65. tunacode/ui/tool_ui.py +8 -2
  66. tunacode/utils/bm25.py +4 -4
  67. tunacode/utils/file_utils.py +2 -2
  68. tunacode/utils/message_utils.py +3 -1
  69. tunacode/utils/system.py +0 -4
  70. tunacode/utils/text_utils.py +1 -1
  71. tunacode/utils/token_counter.py +2 -2
  72. {tunacode_cli-0.0.50.dist-info → tunacode_cli-0.0.53.dist-info}/METADATA +146 -1
  73. tunacode_cli-0.0.53.dist-info/RECORD +123 -0
  74. {tunacode_cli-0.0.50.dist-info → tunacode_cli-0.0.53.dist-info}/top_level.txt +0 -1
  75. api/auth.py +0 -13
  76. api/users.py +0 -8
  77. tunacode/core/recursive/__init__.py +0 -18
  78. tunacode/core/recursive/aggregator.py +0 -467
  79. tunacode/core/recursive/budget.py +0 -414
  80. tunacode/core/recursive/decomposer.py +0 -398
  81. tunacode/core/recursive/executor.py +0 -470
  82. tunacode/core/recursive/hierarchy.py +0 -488
  83. tunacode/ui/recursive_progress.py +0 -380
  84. tunacode_cli-0.0.50.dist-info/RECORD +0 -107
  85. {tunacode_cli-0.0.50.dist-info → tunacode_cli-0.0.53.dist-info}/WHEEL +0 -0
  86. {tunacode_cli-0.0.50.dist-info → tunacode_cli-0.0.53.dist-info}/entry_points.txt +0 -0
  87. {tunacode_cli-0.0.50.dist-info → tunacode_cli-0.0.53.dist-info}/licenses/LICENSE +0 -0
@@ -35,8 +35,48 @@ def get_agent_tool():
35
35
 
36
36
 
37
37
def _build_prompt_part_fallback(module_name: str, class_name: str, default_role: str) -> type:
    """Create a minimal stand-in for a ``pydantic_ai.messages`` part class.

    The generated class stores the ``content``, ``role`` and ``part_kind``
    constructor arguments as attributes. It reports ``module_name`` as its
    ``__module__`` so it appears to come from that module.
    """

    def __init__(self, content: str = "", role: str = default_role, part_kind: str = ""):
        self.content = content
        self.role = role
        self.part_kind = part_kind

    def __repr__(self) -> str:  # pragma: no cover
        return f"{class_name}(content={self.content!r})"

    return type(
        class_name,
        (),
        {"__init__": __init__, "__repr__": __repr__, "__module__": module_name},
    )


def get_model_messages():
    """
    Safely retrieve message-related classes from pydantic_ai.

    If the running environment (e.g. our test stubs) does not define
    SystemPromptPart or UserPromptPart, a minimal placeholder is created.
    The placeholder is attached to ``pydantic_ai.messages`` so the rest of
    the code can continue to work without the real implementation.

    Returns:
        A ``(ModelRequest, ToolReturnPart, SystemPromptPart)`` tuple taken
        from ``pydantic_ai.messages``.
    """
    messages = importlib.import_module("pydantic_ai.messages")

    # Only fill in classes the (possibly stubbed) module lacks; existing
    # classes are never replaced.
    for class_name, default_role in (("SystemPromptPart", "system"), ("UserPromptPart", "user")):
        if not hasattr(messages, class_name):
            setattr(
                messages,
                class_name,
                _build_prompt_part_fallback(messages.__name__, class_name, default_role),
            )

    # Finally, return the relevant classes so callers can use them directly
    return messages.ModelRequest, messages.ToolReturnPart, messages.SystemPromptPart
40
80
 
41
81
 
42
82
  async def execute_tools_parallel(
@@ -335,7 +375,7 @@ def patch_tool_messages(
335
375
  for tool_call_id, tool_name in list(tool_calls.items()):
336
376
  if tool_call_id not in tool_returns and tool_call_id not in retry_prompts:
337
377
  # Import ModelRequest and ToolReturnPart lazily
338
- model_request_cls, tool_return_part_cls = get_model_messages()
378
+ model_request_cls, tool_return_part_cls, _ = get_model_messages()
339
379
  messages.append(
340
380
  model_request_cls(
341
381
  parts=[
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import asyncio
6
6
  import uuid
7
7
  from collections import defaultdict
8
- from typing import Awaitable, Callable, Dict, List
8
+ from typing import Any, Callable, Coroutine, Dict, List
9
9
 
10
10
 
11
11
  class BackgroundTaskManager:
@@ -15,9 +15,9 @@ class BackgroundTaskManager:
15
15
  self.tasks: Dict[str, asyncio.Task] = {}
16
16
  self.listeners: Dict[str, List[Callable[[asyncio.Task], None]]] = defaultdict(list)
17
17
 
18
- def spawn(self, coro: Awaitable, *, name: str | None = None) -> str:
18
+ def spawn(self, coro: Coroutine[Any, Any, Any], *, name: str | None = None) -> str:
19
19
  task_id = name or uuid.uuid4().hex[:8]
20
- task = asyncio.create_task(coro, name=task_id)
20
+ task: asyncio.Task = asyncio.create_task(coro, name=task_id)
21
21
  self.tasks[task_id] = task
22
22
  task.add_done_callback(self._notify)
23
23
  return task_id
@@ -1,21 +1,22 @@
1
1
  import logging
2
+ from typing import Any
2
3
 
3
4
# Custom log level sitting between INFO (20) and WARNING (30).
THOUGHT = 25
logging.addLevelName(THOUGHT, "THOUGHT")


def thought(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None:
    """Log ``message`` at the custom THOUGHT level when that level is enabled."""
    if not self.isEnabledFor(THOUGHT):
        return
    self._log(THOUGHT, message, args, **kwargs)


# Attach as a method so every Logger gains ``logger.thought(...)``.
setattr(logging.Logger, "thought", thought)
14
15
 
15
16
 
16
17
  # RichHandler for UI output (stub, real implementation in handlers.py)
17
18
class RichHandler(logging.Handler):
    """Placeholder handler; the working implementation lives in handlers.py."""

    def emit(self, _record):
        """Ignore the record: this stub produces no output."""
21
22
 
@@ -1,19 +1,38 @@
1
1
  import logging
2
2
  import logging.config
3
- import os
4
-
5
- import yaml
6
3
 
7
4
  from tunacode.utils import user_configuration
8
5
 
9
- DEFAULT_CONFIG_PATH = os.path.join(
10
- os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "config", "logging.yaml"
11
- )
6
+ # Default logging configuration when none is provided
7
DEFAULT_LOGGING_CONFIG = dict(
    version=1,
    disable_existing_loggers=False,
    formatters=dict(
        simple={"format": "[%(levelname)s] %(message)s"},
        detailed={
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s:%(lineno)d] - %(message)s"
        },
    ),
    handlers=dict(
        # Rotating log file: the relative filename resolves against the
        # current working directory. 10 MiB per file, five backups kept.
        file={
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "detailed",
            "filename": "tunacode.log",
            "maxBytes": 10 * 1024 * 1024,
            "backupCount": 5,
        },
    ),
    root={"level": "DEBUG", "handlers": ["file"]},
    loggers={
        # NOTE(review): "tunacode.ui" does not propagate and has no handler of
        # its own, so its records never reach the file handler — confirm intended.
        "tunacode.ui": {"level": "INFO", "propagate": False},
        "tunacode.tools": {"level": "DEBUG"},
        "tunacode.core.agents": {"level": "DEBUG"},
    },
)
12
31
 
13
32
 
14
33
  class LogConfig:
15
34
  @staticmethod
16
- def load(config_path=None):
35
+ def load(_config_path=None):
17
36
  """
18
37
  Load logging configuration based on user preferences.
19
38
  If logging is disabled (default), use minimal configuration.
@@ -43,15 +62,9 @@ class LogConfig:
43
62
  print(f"Failed to configure custom logging: {e}")
44
63
  logging.basicConfig(level=logging.INFO)
45
64
  else:
46
- # Use default configuration from YAML file
47
- path = config_path or DEFAULT_CONFIG_PATH
48
- if not os.path.exists(path):
49
- raise FileNotFoundError(f"Logging config file not found: {path}")
50
- with open(path, "r") as f:
51
- config = yaml.safe_load(f)
52
- logging_config = config.get("logging", config)
65
+ # Use default configuration
53
66
  try:
54
- logging.config.dictConfig(logging_config)
67
+ logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
55
68
  except Exception as e:
56
- print(f"Failed to configure logging: {e}")
69
+ print(f"Failed to configure default logging: {e}")
57
70
  logging.basicConfig(level=logging.INFO)
@@ -32,7 +32,7 @@ try:
32
32
  except ImportError:
33
33
  import json
34
34
 
35
- class JSONFormatter(logging.Formatter):
35
+ class JSONFormatter(logging.Formatter): # type: ignore[no-redef]
36
36
  """
37
37
  Fallback JSON formatter if pythonjsonlogger is not installed.
38
38
  """
@@ -26,11 +26,29 @@ class RichHandler(logging.Handler):
26
26
  super().__init__(level)
27
27
  self.console = Console()
28
28
 
29
+ def _safe_str(self, value):
30
+ """Coerce any value to a safe string representation."""
31
+ try:
32
+ if value is None:
33
+ return ""
34
+ return str(value)
35
+ except Exception:
36
+ return ""
37
+
29
38
  def emit(self, record):
30
39
  try:
40
+ # Defensive normalization of record fields to avoid None propagation
41
+ record.levelname = self._safe_str(getattr(record, "levelname", "INFO")) or "INFO"
31
42
  icon = self.level_icons.get(record.levelname, "")
32
43
  timestamp = self.formatTime(record)
44
+
45
+ # Ensure message formatting never returns None
33
46
  msg = self.format(record)
47
+ if msg is None:
48
+ msg = ""
49
+
50
+ msg = self._safe_str(msg)
51
+
34
52
  if icon:
35
53
  output = f"[{timestamp}] {icon} {msg}"
36
54
  else:
@@ -41,9 +59,9 @@ class RichHandler(logging.Handler):
41
59
  if just_finished_streaming:
42
60
  _streaming_context["just_finished"] = False # Reset after use
43
61
  # Don't add extra newline when transitioning from streaming
44
- self.console.print(Text(output), end="\n")
62
+ self.console.print(Text(self._safe_str(output)), end="\n")
45
63
  else:
46
- self.console.print(Text(output))
64
+ self.console.print(Text(self._safe_str(output)))
47
65
  except Exception:
48
66
  self.handleError(record)
49
67
 
@@ -61,15 +79,31 @@ class StructuredFileHandler(logging.FileHandler):
61
79
  Handler that outputs logs as structured JSON lines.
62
80
  """
63
81
 
82
+ def _coerce_json_safe(self, value):
83
+ """Ensure values are JSON-serializable and not None."""
84
+ if value is None:
85
+ return ""
86
+ try:
87
+ json.dumps(value)
88
+ return value
89
+ except Exception:
90
+ try:
91
+ return str(value)
92
+ except Exception:
93
+ return ""
94
+
64
95
  def emit(self, record):
65
96
  try:
97
+ # Normalize fields to avoid None values in JSON
66
98
  log_entry = {
67
99
  "timestamp": self.formatTime(record),
68
- "level": record.levelname,
69
- "name": record.name,
70
- "line": record.lineno,
71
- "message": record.getMessage(),
72
- "extra_data": getattr(record, "extra", {}),
100
+ "level": self._coerce_json_safe(getattr(record, "levelname", "")),
101
+ "name": self._coerce_json_safe(getattr(record, "name", "")),
102
+ "line": int(getattr(record, "lineno", 0) or 0),
103
+ "message": self._coerce_json_safe(
104
+ record.getMessage() if hasattr(record, "getMessage") else ""
105
+ ),
106
+ "extra_data": self._coerce_json_safe(getattr(record, "extra", {})),
73
107
  }
74
108
  self.stream.write(json.dumps(log_entry) + "\n")
75
109
  self.flush()
@@ -4,6 +4,7 @@ from .config_setup import ConfigSetup
4
4
  from .coordinator import SetupCoordinator
5
5
  from .environment_setup import EnvironmentSetup
6
6
  from .git_safety_setup import GitSafetySetup
7
+ from .template_setup import TemplateSetup
7
8
 
8
9
  __all__ = [
9
10
  "BaseSetup",
@@ -12,4 +13,5 @@ __all__ = [
12
13
  "EnvironmentSetup",
13
14
  "GitSafetySetup",
14
15
  "AgentSetup",
16
+ "TemplateSetup",
15
17
  ]
@@ -22,11 +22,11 @@ class AgentSetup(BaseSetup):
22
22
  def name(self) -> str:
23
23
  return "Agent"
24
24
 
25
- async def should_run(self, force_setup: bool = False) -> bool:
25
+ async def should_run(self, _force_setup: bool = False) -> bool:
26
26
  """Agent setup should run if an agent is provided."""
27
27
  return self.agent is not None
28
28
 
29
- async def execute(self, force_setup: bool = False) -> None:
29
+ async def execute(self, _force_setup: bool = False) -> None:
30
30
  """Initialize the agent with the current model."""
31
31
  if self.agent is not None:
32
32
  await ui.info(f"Initializing Agent({self.state_manager.session.current_model})")
@@ -22,12 +22,12 @@ class BaseSetup(ABC):
22
22
  pass
23
23
 
24
24
  @abstractmethod
25
- async def should_run(self, force_setup: bool = False) -> bool:
25
+ async def should_run(self, _force_setup: bool = False) -> bool:
26
26
  """Determine if this setup step should run."""
27
27
  pass
28
28
 
29
29
  @abstractmethod
30
- async def execute(self, force_setup: bool = False) -> None:
30
+ async def execute(self, _force_setup: bool = False) -> None:
31
31
  """Execute the setup step."""
32
32
  pass
33
33
 
@@ -272,17 +272,21 @@ class ConfigSetup(BaseSetup):
272
272
  self.state_manager.session.user_config = DEFAULT_USER_CONFIG.copy()
273
273
 
274
274
  # Apply CLI overrides
275
- if self.cli_config.get("key"):
275
+ if self.cli_config and self.cli_config.get("key"):
276
276
  # Ensure env dict exists
277
277
  if "env" not in self.state_manager.session.user_config:
278
278
  self.state_manager.session.user_config["env"] = {}
279
279
 
280
280
  # Determine which API key to set based on the model or baseurl
281
- if self.cli_config.get("baseurl") and "openrouter" in self.cli_config["baseurl"]:
281
+ if (
282
+ self.cli_config
283
+ and self.cli_config.get("baseurl")
284
+ and "openrouter" in self.cli_config["baseurl"]
285
+ ):
282
286
  self.state_manager.session.user_config["env"]["OPENROUTER_API_KEY"] = (
283
287
  self.cli_config["key"]
284
288
  )
285
- elif self.cli_config.get("model"):
289
+ elif self.cli_config and self.cli_config.get("model"):
286
290
  if "claude" in self.cli_config["model"] or "anthropic" in self.cli_config["model"]:
287
291
  self.state_manager.session.user_config["env"]["ANTHROPIC_API_KEY"] = (
288
292
  self.cli_config["key"]
@@ -301,12 +305,12 @@ class ConfigSetup(BaseSetup):
301
305
  self.cli_config["key"]
302
306
  )
303
307
 
304
- if self.cli_config.get("baseurl"):
308
+ if self.cli_config and self.cli_config.get("baseurl"):
305
309
  self.state_manager.session.user_config["env"]["OPENAI_BASE_URL"] = self.cli_config[
306
310
  "baseurl"
307
311
  ]
308
312
 
309
- if self.cli_config.get("model"):
313
+ if self.cli_config and self.cli_config.get("model"):
310
314
  model = self.cli_config["model"]
311
315
  # Require provider prefix
312
316
  if ":" not in model:
@@ -318,7 +322,7 @@ class ConfigSetup(BaseSetup):
318
322
 
319
323
  self.state_manager.session.user_config["default_model"] = model
320
324
 
321
- if self.cli_config.get("custom_context_window"):
325
+ if self.cli_config and self.cli_config.get("custom_context_window"):
322
326
  self.state_manager.session.user_config["context_window_size"] = self.cli_config[
323
327
  "custom_context_window"
324
328
  ]
@@ -32,12 +32,12 @@ class GitSafetySetup(BaseSetup):
32
32
  """Return the name of this setup step."""
33
33
  return "Git Safety"
34
34
 
35
- async def should_run(self, force: bool = False) -> bool:
35
+ async def should_run(self, _force: bool = False) -> bool:
36
36
  """Check if we should run git safety setup."""
37
37
  # Always run unless user has explicitly disabled it
38
38
  return not self.state_manager.session.user_config.get("skip_git_safety", False)
39
39
 
40
- async def execute(self, force: bool = False) -> None:
40
+ async def execute(self, _force: bool = False) -> None:
41
41
  """Create a safety branch for TunaCode operations."""
42
42
  try:
43
43
  # Check if git is installed
@@ -123,6 +123,16 @@ class GitSafetySetup(BaseSetup):
123
123
  )
124
124
  # Save preference
125
125
  self.state_manager.session.user_config["skip_git_safety"] = True
126
+ # Save the updated configuration to disk
127
+ try:
128
+ from tunacode.utils.user_configuration import save_config
129
+
130
+ save_config(self.state_manager)
131
+ except Exception as e:
132
+ # Log the error but don't fail the setup process
133
+ import logging
134
+
135
+ logging.warning(f"Failed to save skip_git_safety preference: {e}")
126
136
  return
127
137
 
128
138
  # Create and checkout the new branch
@@ -132,6 +142,7 @@ class GitSafetySetup(BaseSetup):
132
142
  ["git", "show-ref", "--verify", f"refs/heads/{new_branch}"],
133
143
  capture_output=True,
134
144
  check=False,
145
+ text=True,
135
146
  )
136
147
 
137
148
  if result.returncode == 0:
@@ -0,0 +1,75 @@
1
+ """Module: tunacode.core.setup.template_setup
2
+
3
+ Template directory initialization for the TunaCode CLI.
4
+ Handles creation of template directories and ensures proper structure.
5
+ """
6
+
7
+ import platform
8
+ from pathlib import Path
9
+
10
+ from tunacode.core.setup.base import BaseSetup
11
+ from tunacode.core.state import StateManager
12
+ from tunacode.ui import console as ui
13
+
14
+
15
class TemplateSetup(BaseSetup):
    """Setup step that creates the user's template directory structure."""

    def __init__(self, state_manager: StateManager):
        super().__init__(state_manager)
        # Use same config directory as main configuration
        self.config_dir = Path.home() / ".config" / "tunacode"
        self.template_dir = self.config_dir / "templates"

    @property
    def name(self) -> str:
        """Human-readable name of this setup step."""
        return "Template Directory"

    async def should_run(self, force_setup: bool = False) -> bool:
        """Run when forced or when the template path is not an existing directory.

        ``is_dir()`` (rather than ``exists()``) ensures a stray *file* at the
        template path is not mistaken for a usable directory. In that case
        ``execute`` runs and reports the conflict instead of silently skipping.
        """
        return force_setup or not self.template_dir.is_dir()

    async def execute(self, force_setup: bool = False) -> None:
        """Create the template directory and its subdirectories.

        On non-Windows systems every created directory is chmod-ed to 0o755.

        Raises:
            PermissionError: If a directory cannot be created or chmod-ed.
            OSError: For other filesystem failures, e.g. a non-directory
                already occupying one of the paths.
        """
        # Subdirectories for organization (optional, for future use)
        subdirs = ["project", "tool", "config"]
        try:
            self.template_dir.mkdir(parents=True, exist_ok=True)
            for subdir in subdirs:
                (self.template_dir / subdir).mkdir(exist_ok=True)

            # Set appropriate permissions on Unix-like systems
            if platform.system() != "Windows":
                import os

                os.chmod(self.template_dir, 0o755)
                for subdir in subdirs:
                    os.chmod(self.template_dir / subdir, 0o755)

            await ui.info(f"Created template directory structure at: {self.template_dir}")

        except PermissionError:
            await ui.error(
                f"Permission denied: Cannot create template directory at {self.template_dir}"
            )
            raise
        except OSError as e:
            await ui.error(f"Failed to create template directory: {e}")
            raise

    async def validate(self) -> bool:
        """Return True if the template directory exists and is writable.

        Writability is probed by creating and then deleting a ``.test_write``
        file inside the directory. Any filesystem error yields False.
        """
        if not self.template_dir.is_dir():
            return False

        probe = self.template_dir / ".test_write"
        try:
            probe.touch()
            probe.unlink()
        except OSError:
            return False
        return True
tunacode/core/state.py CHANGED
@@ -6,7 +6,7 @@ Handles user preferences, conversation history, and runtime state.
6
6
 
7
7
  import uuid
8
8
  from dataclasses import dataclass, field
9
- from typing import Any, Optional
9
+ from typing import TYPE_CHECKING, Any, Optional
10
10
 
11
11
  from tunacode.types import (
12
12
  DeviceId,
@@ -21,6 +21,9 @@ from tunacode.types import (
21
21
  from tunacode.utils.message_utils import get_message_content
22
22
  from tunacode.utils.token_counter import estimate_tokens
23
23
 
24
+ if TYPE_CHECKING:
25
+ from tunacode.core.tool_handler import ToolHandler
26
+
24
27
 
25
28
  @dataclass
26
29
  class SessionState:
@@ -91,11 +94,19 @@ class SessionState:
91
94
  class StateManager:
92
95
  def __init__(self):
93
96
  self._session = SessionState()
97
+ self._tool_handler: Optional["ToolHandler"] = None
94
98
 
95
99
  @property
96
100
  def session(self) -> SessionState:
97
101
  return self._session
98
102
 
103
+ @property
104
+ def tool_handler(self) -> Optional["ToolHandler"]:
105
+ return self._tool_handler
106
+
107
+ def set_tool_handler(self, handler: "ToolHandler") -> None:
108
+ self._tool_handler = handler
109
+
99
110
  def add_todo(self, todo: TodoItem) -> None:
100
111
  self._session.todos.append(todo)
101
112
 
@@ -112,7 +123,7 @@ class StateManager:
112
123
  def push_recursive_context(self, context: dict[str, Any]) -> None:
113
124
  """Push a new context onto the recursive execution stack."""
114
125
  self._session.recursive_context_stack.append(context)
115
- self._session.current_recursion_depth += 1
126
+ self._session.current_recursion_depth = (self._session.current_recursion_depth or 0) + 1
116
127
 
117
128
  def pop_recursive_context(self) -> Optional[dict[str, Any]]:
118
129
  """Pop the current context from the recursive execution stack."""
@@ -35,9 +35,13 @@ class ApiResponseParser:
35
35
 
36
36
  # The pydantic-ai Usage object standardizes keys to 'request_tokens'
37
37
  # and 'response_tokens'. We access them as attributes.
38
+ # Ensure None values are converted to 0
39
+ prompt_tokens = getattr(usage, "request_tokens", 0)
40
+ completion_tokens = getattr(usage, "response_tokens", 0)
41
+
38
42
  parsed_data = {
39
- "prompt_tokens": getattr(usage, "request_tokens", 0),
40
- "completion_tokens": getattr(usage, "response_tokens", 0),
43
+ "prompt_tokens": prompt_tokens if prompt_tokens is not None else 0,
44
+ "completion_tokens": completion_tokens if completion_tokens is not None else 0,
41
45
  "model_name": actual_model_name,
42
46
  }
43
47
 
@@ -46,7 +46,11 @@ class UsageTracker(UsageTrackerProtocol):
46
46
 
47
47
  except Exception as e:
48
48
  if self.state_manager.session.show_thoughts:
49
+ import traceback
50
+
49
51
  await ui.error(f"Error during cost calculation: {e}")
52
+ # Log the full traceback for debugging
53
+ await ui.debug(f"Traceback: {traceback.format_exc()}")
50
54
 
51
55
  def _calculate_cost(self, parsed_data: dict) -> float:
52
56
  """Calculates the cost for the given parsed data."""
@@ -81,15 +85,35 @@ class UsageTracker(UsageTrackerProtocol):
81
85
  if session.session_total_usage is None:
82
86
  session.session_total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "cost": 0.0}
83
87
 
88
+ # Normalize values defensively to avoid None propagation
89
+ try:
90
+ prompt_tokens = int(prompt_tokens or 0)
91
+ except (TypeError, ValueError):
92
+ prompt_tokens = 0
93
+ try:
94
+ completion_tokens = int(completion_tokens or 0)
95
+ except (TypeError, ValueError):
96
+ completion_tokens = 0
97
+ try:
98
+ cost = float(cost or 0.0)
99
+ except (TypeError, ValueError):
100
+ cost = 0.0
101
+
84
102
  # Update last call usage
85
103
  session.last_call_usage["prompt_tokens"] = prompt_tokens
86
104
  session.last_call_usage["completion_tokens"] = completion_tokens
87
105
  session.last_call_usage["cost"] = cost
88
106
 
89
- # Accumulate session totals
90
- session.session_total_usage["prompt_tokens"] += prompt_tokens
91
- session.session_total_usage["completion_tokens"] += completion_tokens
92
- session.session_total_usage["cost"] += cost
107
+ # Accumulate session totals with normalization
108
+ session.session_total_usage["prompt_tokens"] = (
109
+ int(session.session_total_usage.get("prompt_tokens", 0) or 0) + prompt_tokens
110
+ )
111
+ session.session_total_usage["completion_tokens"] = (
112
+ int(session.session_total_usage.get("completion_tokens", 0) or 0) + completion_tokens
113
+ )
114
+ session.session_total_usage["cost"] = (
115
+ float(session.session_total_usage.get("cost", 0.0) or 0.0) + cost
116
+ )
93
117
 
94
118
  async def _display_summary(self):
95
119
  """Formats and prints the usage summary to the console."""
@@ -106,9 +130,15 @@ class UsageTracker(UsageTrackerProtocol):
106
130
  last_cost = session.last_call_usage["cost"]
107
131
  session_cost = session.session_total_usage["cost"]
108
132
 
133
+ # Ensure tokens are not None before arithmetic operations
134
+ prompt_safe = prompt if prompt is not None else 0
135
+ completion_safe = completion if completion is not None else 0
136
+ last_cost_safe = last_cost if last_cost is not None else 0.0
137
+ session_cost_safe = session_cost if session_cost is not None else 0.0
138
+
109
139
  usage_summary = (
110
- f"[ Tokens: {prompt + completion:,} (P: {prompt:,}, C: {completion:,}) | "
111
- f"Cost: ${last_cost:.4f} | "
112
- f"Session Total: ${session_cost:.4f} ]"
140
+ f"[ Tokens: {prompt_safe + completion_safe:,} (P: {prompt_safe:,}, C: {completion_safe:,}) | "
141
+ f"Cost: ${last_cost_safe:.4f} | "
142
+ f"Session Total: ${session_cost_safe:.4f} ]"
113
143
  )
114
144
  await ui.muted(usage_summary)
@@ -2,9 +2,17 @@
2
2
  Tool handling business logic, separated from UI concerns.
3
3
  """
4
4
 
5
+ from typing import Optional
6
+
5
7
  from tunacode.constants import READ_ONLY_TOOLS
6
8
  from tunacode.core.state import StateManager
7
- from tunacode.types import ToolArgs, ToolConfirmationRequest, ToolConfirmationResponse, ToolName
9
+ from tunacode.templates.loader import Template
10
+ from tunacode.types import (
11
+ ToolArgs,
12
+ ToolConfirmationRequest,
13
+ ToolConfirmationResponse,
14
+ ToolName,
15
+ )
8
16
 
9
17
 
10
18
  class ToolHandler:
@@ -12,6 +20,16 @@ class ToolHandler:
12
20
 
13
21
  def __init__(self, state_manager: StateManager):
14
22
  self.state = state_manager
23
+ self.active_template: Optional[Template] = None
24
+
25
+ def set_active_template(self, template: Optional[Template]) -> None:
26
+ """
27
+ Set the currently active template.
28
+
29
+ Args:
30
+ template: The template to activate, or None to clear the active template.
31
+ """
32
+ self.active_template = template
15
33
 
16
34
  def should_confirm(self, tool_name: ToolName) -> bool:
17
35
  """
@@ -27,6 +45,11 @@ class ToolHandler:
27
45
  if is_read_only_tool(tool_name):
28
46
  return False
29
47
 
48
+ # Check if tool is allowed by active template
49
+ if self.active_template and self.active_template.allowed_tools:
50
+ if tool_name in self.active_template.allowed_tools:
51
+ return False
52
+
30
53
  return not (self.state.session.yolo or tool_name in self.state.session.tool_ignore)
31
54
 
32
55
  def process_confirmation(self, response: ToolConfirmationResponse, tool_name: ToolName) -> bool: