ripperdoc 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (61)
  1. ripperdoc/__init__.py +1 -1
  2. ripperdoc/cli/cli.py +9 -2
  3. ripperdoc/cli/commands/agents_cmd.py +8 -4
  4. ripperdoc/cli/commands/context_cmd.py +3 -3
  5. ripperdoc/cli/commands/cost_cmd.py +5 -0
  6. ripperdoc/cli/commands/doctor_cmd.py +12 -4
  7. ripperdoc/cli/commands/memory_cmd.py +6 -13
  8. ripperdoc/cli/commands/models_cmd.py +36 -6
  9. ripperdoc/cli/commands/resume_cmd.py +4 -2
  10. ripperdoc/cli/commands/status_cmd.py +1 -1
  11. ripperdoc/cli/ui/rich_ui.py +135 -2
  12. ripperdoc/cli/ui/thinking_spinner.py +128 -0
  13. ripperdoc/core/agents.py +174 -6
  14. ripperdoc/core/config.py +9 -1
  15. ripperdoc/core/default_tools.py +6 -0
  16. ripperdoc/core/providers/__init__.py +47 -0
  17. ripperdoc/core/providers/anthropic.py +147 -0
  18. ripperdoc/core/providers/base.py +236 -0
  19. ripperdoc/core/providers/gemini.py +496 -0
  20. ripperdoc/core/providers/openai.py +253 -0
  21. ripperdoc/core/query.py +337 -141
  22. ripperdoc/core/query_utils.py +65 -24
  23. ripperdoc/core/system_prompt.py +67 -61
  24. ripperdoc/core/tool.py +12 -3
  25. ripperdoc/sdk/client.py +12 -1
  26. ripperdoc/tools/ask_user_question_tool.py +433 -0
  27. ripperdoc/tools/background_shell.py +104 -18
  28. ripperdoc/tools/bash_tool.py +33 -13
  29. ripperdoc/tools/enter_plan_mode_tool.py +223 -0
  30. ripperdoc/tools/exit_plan_mode_tool.py +150 -0
  31. ripperdoc/tools/file_edit_tool.py +13 -0
  32. ripperdoc/tools/file_read_tool.py +16 -0
  33. ripperdoc/tools/file_write_tool.py +13 -0
  34. ripperdoc/tools/glob_tool.py +5 -1
  35. ripperdoc/tools/ls_tool.py +14 -10
  36. ripperdoc/tools/mcp_tools.py +113 -4
  37. ripperdoc/tools/multi_edit_tool.py +12 -0
  38. ripperdoc/tools/notebook_edit_tool.py +12 -0
  39. ripperdoc/tools/task_tool.py +88 -5
  40. ripperdoc/tools/todo_tool.py +1 -3
  41. ripperdoc/tools/tool_search_tool.py +8 -4
  42. ripperdoc/utils/file_watch.py +134 -0
  43. ripperdoc/utils/git_utils.py +36 -38
  44. ripperdoc/utils/json_utils.py +1 -2
  45. ripperdoc/utils/log.py +3 -4
  46. ripperdoc/utils/mcp.py +49 -10
  47. ripperdoc/utils/memory.py +1 -3
  48. ripperdoc/utils/message_compaction.py +5 -11
  49. ripperdoc/utils/messages.py +9 -13
  50. ripperdoc/utils/output_utils.py +1 -3
  51. ripperdoc/utils/prompt.py +17 -0
  52. ripperdoc/utils/session_usage.py +7 -0
  53. ripperdoc/utils/shell_utils.py +159 -0
  54. ripperdoc/utils/token_estimation.py +33 -0
  55. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/METADATA +3 -1
  56. ripperdoc-0.2.4.dist-info/RECORD +99 -0
  57. ripperdoc-0.2.2.dist-info/RECORD +0 -86
  58. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/WHEEL +0 -0
  59. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/entry_points.txt +0 -0
  60. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/licenses/LICENSE +0 -0
  61. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/top_level.txt +0 -0
ripperdoc/utils/git_utils.py CHANGED
@@ -42,10 +42,10 @@ def read_gitignore_patterns(path: Path) -> List[str]:
     """Read .gitignore patterns from a directory and its parent directories."""
     patterns: List[str] = []
     current = path
-
+
     # Read .gitignore from current directory up to git root
     git_root = get_git_root(path)
-
+
     while current and (git_root is None or current.is_relative_to(git_root)):
         gitignore_file = current / ".gitignore"
         if gitignore_file.exists():
@@ -57,7 +57,7 @@ def read_gitignore_patterns(path: Path) -> List[str]:
                         patterns.append(line)
             except (IOError, UnicodeDecodeError):
                 pass
-
+
         # Also check for .git/info/exclude
         git_info_exclude = current / ".git" / "info" / "exclude"
         if git_info_exclude.exists():
@@ -69,11 +69,11 @@ def read_gitignore_patterns(path: Path) -> List[str]:
                         patterns.append(line)
             except (IOError, UnicodeDecodeError):
                 pass
-
+
         if current.parent == current:  # Reached root
             break
         current = current.parent
-
+
     # Add global gitignore patterns
     global_gitignore = Path.home() / ".gitignore"
     if global_gitignore.exists():
@@ -85,39 +85,39 @@ def read_gitignore_patterns(path: Path) -> List[str]:
                     patterns.append(line)
         except (IOError, UnicodeDecodeError):
             pass
-
+
     return patterns
 
 
 def parse_gitignore_pattern(pattern: str, root_path: Path) -> Tuple[str, Optional[Path]]:
     """Parse a gitignore pattern and return (relative_pattern, root)."""
     pattern = pattern.strip()
-
+
     # Handle absolute paths
     if pattern.startswith("/"):
         return pattern[1:], root_path
-
+
     # Handle patterns relative to home directory
     if pattern.startswith("~/"):
         home_pattern = pattern[2:]
         return home_pattern, Path.home()
-
+
     # Handle patterns with leading slash (relative to repository root)
     if pattern.startswith("/"):
         return pattern[1:], root_path
-
+
     # Default: pattern is relative to the directory containing .gitignore
     return pattern, None
 
 
 def build_ignore_patterns_map(
-    root_path: Path,
+    root_path: Path,
     user_ignore_patterns: Optional[List[str]] = None,
-    include_gitignore: bool = True
+    include_gitignore: bool = True,
 ) -> Dict[Optional[Path], List[str]]:
     """Build a map of ignore patterns by root directory."""
     ignore_map: Dict[Optional[Path], List[str]] = {}
-
+
     # Add user-provided ignore patterns
     if user_ignore_patterns:
         for pattern in user_ignore_patterns:
@@ -125,7 +125,7 @@ def build_ignore_patterns_map(
             if pattern_root not in ignore_map:
                 ignore_map[pattern_root] = []
             ignore_map[pattern_root].append(relative_pattern)
-
+
     # Add .gitignore patterns
     if include_gitignore and is_git_repository(root_path):
         gitignore_patterns = read_gitignore_patterns(root_path)
@@ -134,31 +134,29 @@ def build_ignore_patterns_map(
             if pattern_root not in ignore_map:
                 ignore_map[pattern_root] = []
             ignore_map[pattern_root].append(relative_pattern)
-
+
     return ignore_map
 
 
 def should_ignore_path(
-    path: Path,
-    root_path: Path,
-    ignore_map: Dict[Optional[Path], List[str]]
+    path: Path, root_path: Path, ignore_map: Dict[Optional[Path], List[str]]
 ) -> bool:
     """Check if a path should be ignored based on ignore patterns."""
     # Check against each root in the ignore map
     for pattern_root, patterns in ignore_map.items():
         # Determine the actual root to use for pattern matching
         actual_root = pattern_root if pattern_root is not None else root_path
-
+
         try:
             # Get relative path from actual_root
             rel_path = path.relative_to(actual_root).as_posix()
         except ValueError:
             # Path is not under this root, skip
             continue
-
+
         # For directories, also check with trailing slash
         rel_path_dir = f"{rel_path}/" if path.is_dir() else rel_path
-
+
         # Check each pattern
         for pattern in patterns:
             # Handle directory-specific patterns
@@ -166,14 +164,14 @@ def should_ignore_path(
                 if not path.is_dir():
                     continue
                 pattern_without_slash = pattern[:-1]
-                if fnmatch.fnmatch(rel_path, pattern_without_slash) or \
-                   fnmatch.fnmatch(rel_path_dir, pattern):
+                if fnmatch.fnmatch(rel_path, pattern_without_slash) or fnmatch.fnmatch(
+                    rel_path_dir, pattern
+                ):
                     return True
             else:
-                if fnmatch.fnmatch(rel_path, pattern) or \
-                   fnmatch.fnmatch(rel_path_dir, pattern):
+                if fnmatch.fnmatch(rel_path, pattern) or fnmatch.fnmatch(rel_path_dir, pattern):
                     return True
-
+
     return False
 
 
@@ -181,10 +179,10 @@ def get_git_status_files(root_path: Path) -> Tuple[List[str], List[str]]:
     """Get tracked and untracked files from git status."""
     tracked: List[str] = []
     untracked: List[str] = []
-
+
     if not is_git_repository(root_path):
         return tracked, untracked
-
+
     try:
         # Get tracked files (modified, added, etc.)
         result = subprocess.run(
@@ -194,25 +192,25 @@ def get_git_status_files(root_path: Path) -> Tuple[List[str], List[str]]:
             text=True,
             timeout=10,
         )
-
+
         if result.returncode == 0:
             for line in result.stdout.strip().split("\n"):
                 if line:
                     status = line[:2].strip()
                     file_path = line[3:].strip()
-
+
                     # Remove quotes if present
                     if file_path.startswith('"') and file_path.endswith('"'):
                         file_path = file_path[1:-1]
-
+
                     if status == "??":  # Untracked
                         untracked.append(file_path)
                     else:  # Tracked (modified, added, etc.)
                         tracked.append(file_path)
-
+
     except (subprocess.SubprocessError, FileNotFoundError):
         pass
-
+
     return tracked, untracked
 
 
@@ -220,7 +218,7 @@ def get_current_git_branch(root_path: Path) -> Optional[str]:
     """Get the current git branch name."""
     if not is_git_repository(root_path):
         return None
-
+
     try:
         result = subprocess.run(
             ["git", "branch", "--show-current"],
@@ -233,7 +231,7 @@ def get_current_git_branch(root_path: Path) -> Optional[str]:
             return result.stdout.strip()
     except (subprocess.SubprocessError, FileNotFoundError):
         pass
-
+
     return None
 
 
@@ -241,7 +239,7 @@ def get_git_commit_hash(root_path: Path) -> Optional[str]:
     """Get the current git commit hash."""
     if not is_git_repository(root_path):
         return None
-
+
     try:
         result = subprocess.run(
             ["git", "rev-parse", "HEAD"],
@@ -254,7 +252,7 @@ def get_git_commit_hash(root_path: Path) -> Optional[str]:
            return result.stdout.strip()[:8]  # Short hash
     except (subprocess.SubprocessError, FileNotFoundError):
         pass
-
+
     return None
 
 
@@ -262,7 +260,7 @@ def is_working_directory_clean(root_path: Path) -> bool:
     """Check if the working directory is clean (no uncommitted changes)."""
     if not is_git_repository(root_path):
         return True
-
+
     try:
         result = subprocess.run(
             ["git", "status", "--porcelain"],
ripperdoc/utils/json_utils.py CHANGED
@@ -12,8 +12,7 @@ logger = get_logger()
 
 
 def safe_parse_json(json_text: Optional[str], log_error: bool = True) -> Optional[Any]:
-    """Best-effort JSON.parse wrapper that returns None on failure.
-    """
+    """Best-effort JSON.parse wrapper that returns None on failure."""
     if not json_text:
         return None
     try:
ripperdoc/utils/log.py CHANGED
@@ -54,9 +54,7 @@ class StructuredFormatter(logging.Formatter):
         }
         if extras:
             try:
-                serialized = json.dumps(
-                    extras, sort_keys=True, ensure_ascii=True, default=str
-                )
+                serialized = json.dumps(extras, sort_keys=True, ensure_ascii=True, default=str)
             except Exception:
                 serialized = str(extras)
         return f"{message} | {serialized}"
@@ -103,7 +101,8 @@ class RipperdocLogger:
                 # Swallow errors while rotating handlers; console logging should continue.
                 self.logger.exception("[logging] Failed to remove existing file handler")
 
-        file_handler = logging.FileHandler(log_file)
+        # Use UTF-8 to avoid Windows code page encoding errors when logs contain non-ASCII text.
+        file_handler = logging.FileHandler(log_file, encoding="utf-8")
         file_handler.setLevel(logging.DEBUG)
         file_formatter = StructuredFormatter("%(asctime)s [%(levelname)s] %(message)s")
         file_handler.setFormatter(file_formatter)
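
The explicit encoding matters because logging.FileHandler otherwise uses the platform's preferred encoding (a legacy code page such as cp1252 on many Windows setups), so log records containing non-ASCII text can raise UnicodeEncodeError. A minimal standalone sketch of the same pattern, not code from the package:

    import logging

    logger = logging.getLogger("demo")
    # Explicit UTF-8 keeps non-ASCII log records intact regardless of the OS code page.
    handler = logging.FileHandler("demo.log", encoding="utf-8")
    handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
    logger.addHandler(handler)
    logger.warning("path contains non-ASCII text: café ✓")
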
ripperdoc/utils/mcp.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import contextvars
 import json
+import shlex
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field, replace
 from pathlib import Path
@@ -12,16 +13,16 @@ from typing import Any, Dict, List, Optional
 
 from ripperdoc import __version__
 from ripperdoc.utils.log import get_logger
-from ripperdoc.utils.message_compaction import estimate_tokens_from_text
+from ripperdoc.utils.token_estimation import estimate_tokens
 
 logger = get_logger()
 
 try:
-    import mcp.types as mcp_types
-    from mcp.client.session import ClientSession
-    from mcp.client.sse import sse_client
-    from mcp.client.stdio import StdioServerParameters, stdio_client
-    from mcp.client.streamable_http import streamablehttp_client
+    import mcp.types as mcp_types  # type: ignore[import-not-found]
+    from mcp.client.session import ClientSession  # type: ignore[import-not-found]
+    from mcp.client.sse import sse_client  # type: ignore[import-not-found]
+    from mcp.client.stdio import StdioServerParameters, stdio_client  # type: ignore[import-not-found]
+    from mcp.client.streamable_http import streamablehttp_client  # type: ignore[import-not-found]
 
     MCP_AVAILABLE = True
 except Exception:  # pragma: no cover - handled gracefully at runtime
@@ -97,10 +98,48 @@ def _ensure_str_dict(raw: object) -> Dict[str, str]:
     return result
 
 
+def _normalize_command(
+    raw_command: Any, raw_args: Any
+) -> tuple[Optional[str], List[str]]:
+    """Normalize MCP server command/args.
+
+    Supports:
+    - command as list -> first element is executable, rest are args
+    - command as string with spaces -> shlex.split into executable/args (when args empty)
+    - command as plain string -> used as-is
+    """
+    args: List[str] = []
+    if isinstance(raw_args, list):
+        args = [str(a) for a in raw_args]
+
+    # Command provided as list: treat first token as command.
+    if isinstance(raw_command, list):
+        tokens = [str(t) for t in raw_command if str(t)]
+        if not tokens:
+            return None, args
+        return tokens[0], tokens[1:] + args
+
+    if not isinstance(raw_command, str):
+        return None, args
+
+    command_str = raw_command.strip()
+    if not command_str:
+        return None, args
+
+    if not args and (" " in command_str or "\t" in command_str):
+        try:
+            tokens = shlex.split(command_str)
+        except ValueError:
+            tokens = [command_str]
+        if tokens:
+            return tokens[0], tokens[1:]
+
+    return command_str, args
+
+
 def _parse_server(name: str, raw: Dict[str, Any]) -> McpServerInfo:
     server_type = str(raw.get("type") or raw.get("transport") or "").strip().lower()
-    command = raw.get("command")
-    args = raw.get("args") if isinstance(raw.get("args"), list) else []
+    command, args = _normalize_command(raw.get("command"), raw.get("args"))
     url = str(raw.get("url") or raw.get("uri") or "").strip() or None
 
     if not server_type:
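
Based on the docstring and branches above, the new helper should behave roughly as follows for the three supported shapes; the server commands below are made-up examples, not configurations shipped with the package:

    from ripperdoc.utils.mcp import _normalize_command  # private helper introduced above

    # command given as a list: first token is the executable, the rest join any args
    assert _normalize_command(["npx", "-y", "example-server"], None) == ("npx", ["-y", "example-server"])
    # command given as a single string with spaces and no args: split with shlex
    assert _normalize_command("npx -y example-server", []) == ("npx", ["-y", "example-server"])
    # plain string command with explicit args: used as-is
    assert _normalize_command("python", ["-m", "example_server"]) == ("python", ["-m", "example_server"])
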
@@ -121,7 +160,7 @@ def _parse_server(name: str, raw: Dict[str, Any]) -> McpServerInfo:
         type=server_type,
         url=url,
         description=description,
-        command=str(command) if isinstance(command, str) else None,
+        command=command,
         args=[str(a) for a in args] if args else [],
         env=env,
         headers=headers,
@@ -482,7 +521,7 @@ def format_mcp_instructions(servers: List[McpServerInfo]) -> str:
 def estimate_mcp_tokens(servers: List[McpServerInfo]) -> int:
     """Estimate token usage for MCP instructions."""
     mcp_text = format_mcp_instructions(servers)
-    return estimate_tokens_from_text(mcp_text)
+    return estimate_tokens(mcp_text)
 
 
 __all__ = [
ripperdoc/utils/memory.py CHANGED
@@ -72,9 +72,7 @@ def _read_file_with_type(file_path: Path, file_type: str) -> Optional[MemoryFile
         content = file_path.read_text(encoding="utf-8", errors="ignore")
         return MemoryFile(path=str(file_path), type=file_type, content=content)
     except PermissionError:
-        logger.exception(
-            "[memory] Permission error reading file", extra={"path": str(file_path)}
-        )
+        logger.exception("[memory] Permission error reading file", extra={"path": str(file_path)})
         return None
     except OSError:
         logger.exception("[memory] OS error reading file", extra={"path": str(file_path)})
ripperdoc/utils/message_compaction.py CHANGED
@@ -3,13 +3,13 @@
 from __future__ import annotations
 
 import json
-import math
 import os
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
 
 from ripperdoc.core.config import GlobalConfig, ModelProfile, get_global_config
 from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.token_estimation import estimate_tokens
 from ripperdoc.utils.messages import (
     AssistantMessage,
     MessageContent,
@@ -140,10 +140,8 @@ def _parse_truthy_env_value(value: Optional[str]) -> bool:
 
 
 def estimate_tokens_from_text(text: str) -> int:
-    """Rough token estimate using a 4-characters-per-token rule."""
-    if not text:
-        return 0
-    return max(1, math.ceil(len(text) / 4))
+    """Estimate token count using shared token estimation helper."""
+    return estimate_tokens(text)
 
 
 def _stringify_content(content: Union[str, List[MessageContent], None]) -> str:
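
estimate_tokens_from_text now delegates to the new ripperdoc/utils/token_estimation.py module, whose contents are not shown in this diff. A minimal sketch of such a shared helper, assuming it keeps the 4-characters-per-token heuristic that the removed code used:

    # Hypothetical sketch; the real token_estimation module may use a different heuristic or a tokenizer.
    import math

    def estimate_tokens(text: str) -> int:
        """Rough token estimate using a 4-characters-per-token rule."""
        if not text:
            return 0
        return max(1, math.ceil(len(text) / 4))
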
@@ -402,9 +400,7 @@ def find_latest_assistant_usage_tokens(
             if tokens > 0:
                 return tokens
         except Exception:
-            logger.debug(
-                "[message_compaction] Failed to parse usage tokens", exc_info=True
-            )
+            logger.debug("[message_compaction] Failed to parse usage tokens", exc_info=True)
             continue
     return 0
 
@@ -441,9 +437,7 @@ def _run_cleanup_callbacks() -> None:
         try:
             callback()
         except Exception as exc:
-            logger.debug(
-                f"[message_compaction] Cleanup callback failed: {exc}", exc_info=True
-            )
+            logger.debug(f"[message_compaction] Cleanup callback failed: {exc}", exc_info=True)
 
 
 def _normalize_tool_use_id(block: Any) -> str:
ripperdoc/utils/messages.py CHANGED
@@ -31,7 +31,7 @@ class MessageContent(BaseModel):
     id: Optional[str] = None
     tool_use_id: Optional[str] = None
     name: Optional[str] = None
-    input: Optional[Dict[str, Any]] = None
+    input: Optional[Dict[str, object]] = None
     is_error: Optional[bool] = None
 
 
@@ -120,7 +120,7 @@ class Message(BaseModel):
     content: Union[str, List[MessageContent]]
     uuid: str = ""
 
-    def __init__(self, **data: Any) -> None:
+    def __init__(self, **data: object) -> None:
         if "uuid" not in data or not data["uuid"]:
             data["uuid"] = str(uuid4())
         super().__init__(**data)
@@ -132,9 +132,9 @@ class UserMessage(BaseModel):
     type: str = "user"
     message: Message
     uuid: str = ""
-    tool_use_result: Optional[Any] = None
+    tool_use_result: Optional[object] = None
 
-    def __init__(self, **data: Any) -> None:
+    def __init__(self, **data: object) -> None:
         if "uuid" not in data or not data["uuid"]:
             data["uuid"] = str(uuid4())
         super().__init__(**data)
@@ -150,7 +150,7 @@ class AssistantMessage(BaseModel):
     duration_ms: float = 0.0
     is_api_error_message: bool = False
 
-    def __init__(self, **data: Any) -> None:
+    def __init__(self, **data: object) -> None:
         if "uuid" not in data or not data["uuid"]:
             data["uuid"] = str(uuid4())
         super().__init__(**data)
@@ -167,14 +167,14 @@ class ProgressMessage(BaseModel):
     sibling_tool_use_ids: set[str] = set()
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    def __init__(self, **data: Any) -> None:
+    def __init__(self, **data: object) -> None:
         if "uuid" not in data or not data["uuid"]:
             data["uuid"] = str(uuid4())
         super().__init__(**data)
 
 
 def create_user_message(
-    content: Union[str, List[Dict[str, Any]]], tool_use_result: Optional[Any] = None
+    content: Union[str, List[Dict[str, Any]]], tool_use_result: Optional[object] = None
 ) -> UserMessage:
     """Create a user message."""
     if isinstance(content, str):
@@ -371,9 +371,7 @@ def normalize_messages_for_api(
                     api_blocks.append(_content_block_to_api(block))
                 normalized.append({"role": "user", "content": api_blocks})
             else:
-                normalized.append(
-                    {"role": "user", "content": user_content}  # type: ignore
-                )
+                normalized.append({"role": "user", "content": user_content})  # type: ignore
         elif msg_type == "assistant":
             asst_content = _msg_content(msg)
             if isinstance(asst_content, list):
428
426
  api_blocks.append(_content_block_to_api(block))
429
427
  normalized.append({"role": "assistant", "content": api_blocks})
430
428
  else:
431
- normalized.append(
432
- {"role": "assistant", "content": asst_content} # type: ignore
433
- )
429
+ normalized.append({"role": "assistant", "content": asst_content}) # type: ignore
434
430
 
435
431
  logger.debug(
436
432
  f"[normalize_messages_for_api] protocol={protocol} tool_mode={effective_tool_mode} "
ripperdoc/utils/output_utils.py CHANGED
@@ -151,9 +151,7 @@ def truncate_output(text: str, max_chars: int = MAX_OUTPUT_CHARS) -> dict[str, A
     available = max(0, max_chars - len(marker))
     keep_start = min(TRUNCATE_KEEP_START, available // 2)
     keep_end = min(TRUNCATE_KEEP_END, available - keep_start)
-    marker = _choose_marker(
-        max(0, original_length - (keep_start + keep_end)), max_chars
-    )
+    marker = _choose_marker(max(0, original_length - (keep_start + keep_end)), max_chars)
 
     available = max(0, max_chars - len(marker))
     # Ensure kept sections fit the final budget; trim end first, then start if needed.
ripperdoc/utils/prompt.py ADDED
@@ -0,0 +1,17 @@
+"""Prompt helpers for interactive input."""
+
+from getpass import getpass
+
+
+def prompt_secret(prompt_text: str, prompt_suffix: str = ": ") -> str:
+    """Prompt for sensitive input, masking characters when possible.
+
+    Falls back to getpass (no echo) if prompt_toolkit is unavailable.
+    """
+    full_prompt = f"{prompt_text}{prompt_suffix}"
+    try:
+        from prompt_toolkit import prompt as pt_prompt
+
+        return pt_prompt(full_prompt, is_password=True)
+    except Exception:
+        return getpass(full_prompt)
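
Typical usage of the new helper, with a made-up prompt string: input is masked via prompt_toolkit when it is installed and falls back to getpass otherwise.

    from ripperdoc.utils.prompt import prompt_secret

    # The prompt label below is illustrative; callers pass whatever text fits their flow.
    api_key = prompt_secret("Enter API key")
    print(f"received {len(api_key)} characters")
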
ripperdoc/utils/session_usage.py CHANGED
@@ -17,6 +17,7 @@ class ModelUsage:
     cache_creation_input_tokens: int = 0
     requests: int = 0
     duration_ms: float = 0.0
+    cost_usd: float = 0.0
 
 
 @dataclass
@@ -49,6 +50,10 @@ class SessionUsage:
     def total_duration_ms(self) -> float:
         return sum(usage.duration_ms for usage in self.models.values())
 
+    @property
+    def total_cost_usd(self) -> float:
+        return sum(usage.cost_usd for usage in self.models.values())
+
 
 _SESSION_USAGE = SessionUsage()
 
@@ -76,6 +81,7 @@ def record_usage(
     cache_read_input_tokens: int = 0,
     cache_creation_input_tokens: int = 0,
     duration_ms: float = 0.0,
+    cost_usd: float = 0.0,
 ) -> None:
     """Record a single model invocation."""
     global _SESSION_USAGE
@@ -88,6 +94,7 @@ def record_usage(
     usage.cache_creation_input_tokens += _as_int(cache_creation_input_tokens)
     usage.duration_ms += float(duration_ms) if duration_ms and duration_ms > 0 else 0.0
     usage.requests += 1
+    usage.cost_usd += float(cost_usd) if cost_usd and cost_usd > 0 else 0.0
 
 
 def get_session_usage() -> SessionUsage:
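
With the new cost_usd field, per-request cost is accumulated per model and surfaced as an aggregate. A small sketch of reading the totals, assuming usage has already been recorded earlier in the session:

    from ripperdoc.utils.session_usage import get_session_usage

    usage = get_session_usage()
    total_requests = sum(m.requests for m in usage.models.values())
    print(f"session cost: ${usage.total_cost_usd:.4f} across {total_requests} request(s)")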