shell-prompter 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompter/__init__.py +21 -0
- prompter/__main__.py +10 -0
- prompter/agent.py +205 -0
- prompter/cli.py +458 -0
- prompter/colors.py +47 -0
- prompter/config.py +138 -0
- prompter/constants.py +17 -0
- prompter/keys.py +59 -0
- prompter/prompts.py +83 -0
- prompter/providers/__init__.py +47 -0
- prompter/providers/anthropic_provider.py +172 -0
- prompter/providers/base.py +266 -0
- prompter/providers/gemini_provider.py +205 -0
- prompter/providers/openai_provider.py +183 -0
- prompter/risk.py +110 -0
- prompter/shell.py +125 -0
- prompter/ui.py +256 -0
- shell_prompter-0.1.0.dist-info/METADATA +289 -0
- shell_prompter-0.1.0.dist-info/RECORD +23 -0
- shell_prompter-0.1.0.dist-info/WHEEL +5 -0
- shell_prompter-0.1.0.dist-info/entry_points.txt +2 -0
- shell_prompter-0.1.0.dist-info/licenses/LICENSE +21 -0
- shell_prompter-0.1.0.dist-info/top_level.txt +1 -0
prompter/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""prompter is a natural-language shell agent powered by an LLM provider (Anthropic, OpenAI, or Gemini).
|
|
2
|
+
|
|
3
|
+
You describe what you want in plain English. prompter figures out the shell
|
|
4
|
+
commands, shows them to you with a risk rating, and runs them after your
|
|
5
|
+
approval for anything non-trivial. It keeps track of the working directory as
|
|
6
|
+
it goes, just like a real shell session.
|
|
7
|
+
|
|
8
|
+
prompter "make a folder called scratch, cd into it, then run claude"
|
|
9
|
+
prompter "download uv if it isn't installed, then check the version"
|
|
10
|
+
prompter (run with no goal for the interactive REPL)
|
|
11
|
+
|
|
12
|
+
The agent's workspace is your shell, not a single repo. Claude is just one of
|
|
13
|
+
the tools it knows how to launch.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from .cli import main
|
|
19
|
+
|
|
20
|
+
__version__ = "0.1.0"
# Only the CLI entry point is part of the public API.
__all__ = ["main"]
|
prompter/__main__.py
ADDED
prompter/agent.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""The agent loop: drive the model's tool calls, gate them by risk, run them,
|
|
2
|
+
and keep the conversation and consecutive-failure state.
|
|
3
|
+
|
|
4
|
+
The loop is provider-neutral. It speaks only the types in providers.base and
|
|
5
|
+
calls one method, provider.complete(). Swapping Anthropic for OpenAI or Gemini
|
|
6
|
+
changes nothing here.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
from .config import ApprovalMode, Config
|
|
14
|
+
from .prompts import build_system_texts
|
|
15
|
+
from .providers.base import (
|
|
16
|
+
AssistantMessage,
|
|
17
|
+
AssistantTurn,
|
|
18
|
+
ModelProvider,
|
|
19
|
+
ToolInvocation,
|
|
20
|
+
ToolResult,
|
|
21
|
+
ToolResultsMessage,
|
|
22
|
+
UserMessage,
|
|
23
|
+
)
|
|
24
|
+
from .risk import RiskAssessment, RiskTier, classify
|
|
25
|
+
from .shell import CommandResult, Shell, looks_interactive
|
|
26
|
+
from .ui import Console, Decision
|
|
27
|
+
|
|
28
|
+
# Tool-result text sent back to the model when the user skips a command.
_DECLINED_MESSAGE = (
    "User declined to run this command. Suggest an alternative "
    "or ask what they'd prefer."
)

# Injected as a user message once max_fix_attempts failures happen in a row.
_FORCE_STOP_TEMPLATE = (
    "{count} commands have failed in a row, "
    "reaching the max_fix_attempts limit ({limit}). "
    "Stop running commands now. "
    "Briefly explain what's going wrong and what the user could try."
)

# Layout of the tool-result payload for a command that actually ran.
_PAYLOAD_TEMPLATE = "\n".join((
    "exit_code: {exit_code}",
    "cwd: {cwd}",
    "stdout:",
    "{stdout}",
    "stderr:",
    "{stderr}",
))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class QuitRequested(Exception):
    """Signal that the user picked 'quit' at a confirmation prompt.

    The agent raises it instead of running the command and lets it propagate
    out of the agent loop.
    """
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class CommandRequest:
    """One shell command the model asked to run, as the agent sees it."""

    command: str
    explanation: str
    interactive: bool

    @classmethod
    def from_invocation(cls, invocation: ToolInvocation) -> CommandRequest:
        """Build a request from a provider-neutral tool call.

        The command counts as interactive when the model flagged it that way.
        Otherwise ``looks_interactive`` decides, and it is only consulted when
        the model did not set the flag.
        """
        cmd = invocation.command
        return cls(
            cmd,
            invocation.explanation,
            invocation.interactive or looks_interactive(cmd),
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class ToolOutcome:
    """What happened when one tool call was handled.

    ``text`` is what gets reported back to the model, and ``is_error`` marks
    that report as an error. ``is_failure`` is stricter: it is True only when a
    command actually ran and exited non-zero. A declined command is therefore
    an error to report, but not a failed fix attempt.
    """

    text: str
    is_error: bool
    is_failure: bool
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class Conversation:
    """Provider-neutral message history, handed to the provider every turn."""

    def __init__(self):
        self.history: list = list()

    def add_user(self, text: str) -> None:
        """Append a plain user message."""
        message = UserMessage(text)
        self.history.append(message)

    def add_assistant(self, turn: AssistantTurn) -> None:
        """Append the assistant's text together with the tool calls it made."""
        message = AssistantMessage(turn.text, turn.tool_calls)
        self.history.append(message)

    def add_tool_results(self, results: list[ToolResult]) -> None:
        """Append the results of one round of tool calls as a single message."""
        message = ToolResultsMessage(results)
        self.history.append(message)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class Agent:
    """Run the model/tool loop for one shell session.

    Each user goal is one turn made of one or more model rounds. For every
    command the model asks for, the agent classifies its risk, authorizes it
    (automatically or by asking the user), runs it in ``shell``, and sends the
    outcome back to the model as a tool result.
    """

    def __init__(self, provider: ModelProvider, shell: Shell, console: Console,
                 config: Config, mode: ApprovalMode):
        self.provider = provider
        self.shell = shell
        self.console = console
        self.config = config
        self.mode = mode
        self.conversation = Conversation()
        # Set when the user answers "all" at a confirmation prompt. run_turn()
        # does not reset it, so it lasts for the rest of this Agent's life.
        self._approve_all = False
        # Commands that ran and failed back to back; any other outcome resets it.
        self.consecutive_failures = 0
        # Passed to provider.complete(); set once max_fix_attempts is reached.
        self._force_stop = False

    def run_turn(self, user_text: str) -> None:
        """Handle one user goal, running rounds until the model stops calling tools.

        The failure streak and the force-stop flag are reset for each goal. The
        conversation history is kept across turns.
        """
        self.consecutive_failures = 0
        self._force_stop = False
        self.conversation.add_user(user_text)
        another_round = True
        while another_round:
            another_round = self._run_round()

    def _run_round(self) -> bool:
        """Run one model turn and return whether another round is needed."""
        turn = self._complete()
        self.conversation.add_assistant(turn)
        if not turn.tool_calls:
            return False
        results = self._process_tool_calls(turn.tool_calls)
        self.conversation.add_tool_results(results)
        self._maybe_force_stop()
        return True

    def _complete(self) -> AssistantTurn:
        """Ask the provider for the next assistant turn, streaming its text.

        ``end_stream()`` runs in a ``finally`` block, so it is paired with
        ``begin_stream()`` even when the provider raises (a ProviderError, or
        Ctrl-C mid-stream).

        Raises:
            Whatever ``provider.complete()`` raises; it propagates unchanged.
        """
        system_texts = build_system_texts(self.config, self.shell.cwd)
        self.console.begin_stream()
        try:
            return self.provider.complete(
                self.conversation.history,
                system_texts,
                self._force_stop,
                self.console.stream_text,
            )
        finally:
            self.console.end_stream()

    def _process_tool_calls(self, tool_calls: list[ToolInvocation]) -> list[ToolResult]:
        """Run each tool call in order and collect one ToolResult per call."""
        results = []
        for invocation in tool_calls:
            outcome = self._execute(CommandRequest.from_invocation(invocation))
            # A real failure extends the streak; a success or a decline resets it.
            self.consecutive_failures = (
                self.consecutive_failures + 1 if outcome.is_failure else 0
            )
            results.append(
                ToolResult(invocation.call_id, outcome.text, outcome.is_error)
            )
        return results

    def _maybe_force_stop(self) -> None:
        """Bound the self-repair loop once too many commands fail in a row.

        This runs at most once per turn. It appends a user message telling the
        model to stop, sets the flag passed to the provider, and tells the user.
        """
        limit = self.config.max_fix_attempts
        if self.consecutive_failures < limit or self._force_stop:
            return
        message = _FORCE_STOP_TEMPLATE.format(
            count=self.consecutive_failures, limit=limit)
        self.conversation.add_user(message)
        self._force_stop = True
        self.console.force_stop_notice(limit)

    def _execute(self, request: CommandRequest) -> ToolOutcome:
        """Authorize and run one command, returning what to report to the model.

        A declined command is reported as an error but not as a failure, so it
        does not count toward max_fix_attempts.

        Raises:
            QuitRequested: propagated from ``_authorize`` when the user quits.
        """
        assessment = classify(request.command)
        if not self._authorize(request, assessment):
            return ToolOutcome(_DECLINED_MESSAGE, is_error=True, is_failure=False)

        result = self.shell.run(request.command, interactive=request.interactive)
        if result.cwd_changed:
            self.console.cwd_change(self.shell.cwd)
        # Captured output is echoed only for non-interactive commands. Presumably
        # interactive ones already wrote to the terminal; TODO confirm in shell.py.
        if not request.interactive:
            self.console.command_output(result)
        return ToolOutcome(
            self._format_payload(result),
            is_error=result.failed,
            is_failure=result.failed,
        )

    def _authorize(self, request: CommandRequest,
                   assessment: RiskAssessment) -> bool:
        """Return True if the command may run, asking the user when required.

        Answering "all" auto-approves every later command for this Agent.

        Raises:
            QuitRequested: if the user chooses quit at the prompt.
        """
        if self._auto_approves(assessment.tier):
            self.console.auto_run(assessment, request.command)
            return True
        decision = self.console.confirm(
            assessment, request.command, request.explanation
        )
        if decision == Decision.QUIT:
            raise QuitRequested
        if decision == Decision.SKIP:
            return False
        if decision == Decision.ALL:
            self._approve_all = True
        return True

    def _auto_approves(self, tier: RiskTier) -> bool:
        """Decide whether a command of ``tier`` can run without confirmation.

        Returns True after an "all" answer or in YOLO mode, and False in
        ASK_ALL mode. Otherwise only SAFE-tier commands are auto-approved.
        """
        if self._approve_all or self.mode == ApprovalMode.YOLO:
            return True
        if self.mode == ApprovalMode.ASK_ALL:
            return False
        return tier == RiskTier.SAFE

    def _format_payload(self, result: CommandResult) -> str:
        """Render a command result as the tool-result text for the model.

        ``cwd`` is the shell's working directory after the command ran.
        """
        return _PAYLOAD_TEMPLATE.format(
            exit_code=result.exit_code,
            cwd=self.shell.cwd,
            stdout=result.stdout,
            stderr=result.stderr,
        )
|
prompter/cli.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
"""Command-line entry point: a command dispatcher with actionable errors.
|
|
2
|
+
|
|
3
|
+
Management is done with subcommands (keys, use, status, config). Flags only
|
|
4
|
+
modify a single run. No management command ever opens an interactive prompt. Every failure
|
|
5
|
+
prints a problem with the exact command to fix it.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
import difflib
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
from .agent import Agent, QuitRequested
|
|
16
|
+
from .config import (
|
|
17
|
+
ApprovalMode,
|
|
18
|
+
CONFIG_PATH,
|
|
19
|
+
Config,
|
|
20
|
+
PROGRAM_NAME,
|
|
21
|
+
PROVIDER_ALIASES,
|
|
22
|
+
PROVIDER_OPENAI,
|
|
23
|
+
default_api_key_env,
|
|
24
|
+
default_model_for,
|
|
25
|
+
load_config,
|
|
26
|
+
normalize_provider,
|
|
27
|
+
save_config,
|
|
28
|
+
)
|
|
29
|
+
from .constants import COMMA_SPACE, EMPTY, SPACE
|
|
30
|
+
from .keys import KEYS_PATH, clear_key, set_key, stored_key
|
|
31
|
+
from .providers import (
|
|
32
|
+
ModelProvider,
|
|
33
|
+
ProviderAuthError,
|
|
34
|
+
ProviderError,
|
|
35
|
+
create_provider,
|
|
36
|
+
known_providers,
|
|
37
|
+
)
|
|
38
|
+
from .shell import Shell
|
|
39
|
+
from .ui import Console
|
|
40
|
+
|
|
41
|
+
# Process exit codes returned by main() and the cmd_* handlers.
OK_EXIT_CODE = 0
ERROR_EXIT_CODE = 1
# 128 + SIGINT(2) by Unix convention; also returned when the user picks "quit".
SIGINT_EXIT_CODE = 130

# First-argument words that main() recognises as subcommands.
_CMD_KEYS = "keys"
_CMD_USE = "use"
_CMD_STATUS = "status"
_CMD_CONFIG = "config"
_CMD_RUN = "run"
_CMD_HELP = "help"
_HELP_FLAGS = ("-h", "--help")

# Actions accepted by `prompter keys ...`; the remove action has aliases.
_KEYS_LIST = "list"
_KEYS_ADD = "add"
_KEYS_REMOVE = ("remove", "rm", "clear")

# REPL lines that end the session (compared after lower-casing).
_EXIT_WORDS = {"exit", "quit", ":q"}

# argparse wiring for the run parser: goal words plus per-run flags.
_ARG_PROMPT = "prompt"
_ARG_NARGS_ANY = "*"
_FLAG_PROVIDER = "--provider"
_FLAG_MODEL = "--model"
_FLAG_BASE_URL = "--base-url"
_FLAG_WORKSPACE = "--workspace"
_FLAG_MAX_FIX = "--max-fix"
_FLAG_ASK_ALL = "--ask-all"
_FLAG_YOLO = "--yolo"
_ACTION_STORE_TRUE = "store_true"

# Labels printed next to suggested commands in problem reports.
_LABEL_ADD_IT = "add it"
_LABEL_OR_EXPORT = "or export"
_LABEL_DID_YOU_MEAN = "did you mean"
_LABEL_AVAILABLE = "available"
_LABEL_SWITCH_TO_IT = "switch to it"
_LABEL_FOR_ONE_RUN = "for one run"
_LABEL_SET_A_MODEL = "set a model"
_LABEL_USAGE = "usage"
_LABEL_EXAMPLE = "example"

# Field names printed by `prompter status`.
_FIELD_PROVIDER = "provider"
_FIELD_MODEL = "model"
_FIELD_WORKSPACE = "workspace"
_FIELD_MAX_FIX = "max-fix"
_FIELD_CONFIG = "config"

# The remaining strings are str.format templates for notes, problem titles and
# hint commands; placeholders are filled in at the call sites.
_BASE_URL_IGNORED = (
    "Note: base_url only applies to the openai provider. Ignoring it for {provider}."
)

# Missing API key.
_TITLE_NO_KEY = "No API key for {provider}"
_CMD_ADD_KEY = "prompter keys add {provider} <key>"
_ENV_EXPORT = "{env}=<key>"

# Unknown provider name.
_TITLE_UNKNOWN_PROVIDER = 'Unknown provider "{name}"'
_HINT_USE = "prompter use {provider}"

# --model was given a provider name.
_TITLE_MODEL_IS_PROVIDER = '"{value}" is a provider, not a model'
_CMD_RUN_PROVIDER = 'prompter --provider {provider} "..."'

_TITLE_BAD_MAX_FIX = "--max-fix must be at least 1 (got {value})"

# Provider request failures. An error text is treated as a model problem when
# it contains _MODEL_WORD and any of _MODEL_ERROR_MARKERS.
_MODEL_WORD = "model"
_TITLE_NO_MODEL = '{provider} has no model "{model}"'
_CMD_USE_MODEL = "prompter use {provider} <model>"
_TITLE_PROVIDER_FAILED = "{provider} request failed"
_MODEL_ERROR_MARKERS = (
    "not found", "does not exist", "is not supported", "no model",
    "invalid model", "not a valid model",
)

# `prompter keys ...` usage errors and confirmations.
_TITLE_UNKNOWN_KEYS_ACTION = 'Unknown "keys" command: "{action}"'
_USAGE_KEYS = "prompter keys [list | add <provider> <key> | remove <provider>]"
_TITLE_KEYS_ADD_USAGE = "Usage: prompter keys add <provider> <key>"
_EXAMPLE_KEYS_ADD = "prompter keys add gemini AIza..."
_TITLE_KEYS_REMOVE_USAGE = "Usage: prompter keys remove <provider>"
_KEY_SAVED = "Saved the {provider} key to {path}"
_KEY_REMOVED = "Removed the stored {provider} key"
_NO_STORED_KEY = "No stored key for {provider}"

# Key-source descriptions shown by `prompter keys list`; stored keys are masked.
_SOURCE_ENV = "set via {env}"
_SOURCE_STORED = "stored ({masked})"
_SOURCE_NOT_SET = "not set"
_MASK_TEMPLATE = "...{tail}"
_MASK_FALLBACK = "set"
_MASK_TAIL_LENGTH = 4

# `prompter use ...` usage error and confirmation.
_TITLE_USE_USAGE = "Usage: prompter use <provider> [model]"
_EXAMPLE_USE = "prompter use gemini"
_DEFAULT_SET = "Default provider set to {provider} (model: {model})"

_KEYS_HEADER = "keys:"

# Text printed by `prompter help` / -h / --help.
_HELP_TEXT = """\
prompter is a natural-language shell agent.

Run:
prompter "<goal>" run a one-off goal
prompter interactive chat (REPL)

Run flags (modify one run):
--provider NAME anthropic, openai, or gemini
--model ID model override
--base-url URL OpenAI-compatible endpoint (Groq, OpenRouter)
--workspace PATH where new projects go
--max-fix N failures in a row before stopping
--ask-all confirm every command
--yolo run everything with no confirmation

Manage:
prompter keys add <provider> <key> store an API key
prompter keys list show key status
prompter keys remove <provider> delete a stored key
prompter use <provider> [model] set your default provider/model
prompter status show the current setup
prompter config print the config file path
prompter help show this help

Providers: anthropic, openai, gemini
"""
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _missing_key_problem(config: Config):
    """Describe a missing API key for ``config.provider``.

    Returns a ``(title, hints, detail)`` triple. The hints show how to store a
    key with ``prompter keys add`` or export it through ``config.key_env``.
    The detail is always None.
    """
    name = config.provider
    title = _TITLE_NO_KEY.format(provider=name)
    hints = [
        (_LABEL_ADD_IT, _CMD_ADD_KEY.format(provider=name)),
        (_LABEL_OR_EXPORT, _ENV_EXPORT.format(env=config.key_env)),
    ]
    return title, hints, None
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _unknown_provider_problem(name: str):
    """Build the ``(title, hints, detail)`` problem for an unknown provider.

    Aliases (claude, gpt, google) are included as fuzzy-match candidates, so a
    near miss like "claud" still produces a suggestion. The suggestion is
    normalized back to the canonical provider name, and the list of available
    providers is always appended.
    """
    known = known_providers()
    pool = [*known, *PROVIDER_ALIASES]
    hints = []
    best = difflib.get_close_matches(name.lower(), pool, n=1)
    if best:
        suggestion = normalize_provider(best[0])
        hints.append((_LABEL_DID_YOU_MEAN, _HINT_USE.format(provider=suggestion)))
    hints.append((_LABEL_AVAILABLE, COMMA_SPACE.join(known)))
    title = _TITLE_UNKNOWN_PROVIDER.format(name=name)
    return title, hints, None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _model_is_provider_problem(value: str, provider: str):
    """Explain that ``--model`` was given a provider name instead of a model.

    Returns ``(title, hints, None)``. The hints show how to make ``provider``
    the default, or how to use it for a single run.
    """
    title = _TITLE_MODEL_IS_PROVIDER.format(value=value)
    switch_cmd = _HINT_USE.format(provider=provider)
    one_off_cmd = _CMD_RUN_PROVIDER.format(provider=provider)
    hints = [(_LABEL_SWITCH_TO_IT, switch_cmd), (_LABEL_FOR_ONE_RUN, one_off_cmd)]
    return title, hints, None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _looks_like_model_error(message: str) -> bool:
    """Heuristically decide whether a provider error is about a bad model name.

    The check is case-insensitive. It requires the word "model" plus at least
    one of the known "not found"-style markers.
    """
    lowered = message.lower()
    if _MODEL_WORD not in lowered:
        return False
    for marker in _MODEL_ERROR_MARKERS:
        if marker in lowered:
            return True
    return False
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _provider_request_problem(config: Config, exc: ProviderError):
    """Turn a failed provider request into a ``(title, hints, detail)`` problem.

    If the error text looks like an unknown or unsupported model, the title
    names the configured model and a hint shows how to pin a different one.
    Otherwise a generic "request failed" title is used. In both cases the
    provider's own message is kept as the detail, so a misfiring model-error
    heuristic never hides the real cause from the user.
    """
    detail = str(exc)
    if _looks_like_model_error(detail):
        return (
            _TITLE_NO_MODEL.format(provider=config.provider, model=config.resolved_model),
            [(_LABEL_SET_A_MODEL, _CMD_USE_MODEL.format(provider=config.provider))],
            detail,
        )
    return (_TITLE_PROVIDER_FAILED.format(provider=config.provider), [], detail)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class _CommandError(Exception):
    """Carry a user-facing problem (title, hints, detail) up to the command handler.

    ``str(exc)`` is the title. ``hints`` is a sequence of ``(label, command)``
    pairs, and ``detail`` is optional extra text.
    """

    def __init__(self, title, hints=(), detail=None):
        super().__init__(title)
        self.title, self.hints, self.detail = title, hints, detail
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _fail(console: Console, title, hints=(), detail=None) -> int:
    """Show a problem report on ``console`` and return ERROR_EXIT_CODE.

    This lets handlers write ``return _fail(...)`` in a single statement.
    """
    report = console.problem
    report(title, hints, detail)
    return ERROR_EXIT_CODE
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _mask(key: str) -> str:
    """Return a display-safe form of a stored API key.

    The last ``_MASK_TAIL_LENGTH`` characters are shown only when the key is
    long enough that more characters stay hidden than are revealed. Shorter
    keys are shown as the plain fallback ("set"). The old ``>=`` check printed
    a key of exactly ``_MASK_TAIL_LENGTH`` characters in full.
    """
    if len(key) > 2 * _MASK_TAIL_LENGTH:
        return _MASK_TEMPLATE.format(tail=key[-_MASK_TAIL_LENGTH:])
    return _MASK_FALLBACK
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _run_parser() -> argparse.ArgumentParser:
    """Build the parser for one agent run: goal words plus per-run flags.

    The parser is built without -h/--help because main() handles help itself.
    """
    parser = argparse.ArgumentParser(prog=PROGRAM_NAME, add_help=False)
    parser.add_argument(_ARG_PROMPT, nargs=_ARG_NARGS_ANY)
    options = (
        (_FLAG_PROVIDER, {}),
        (_FLAG_MODEL, {"default": None}),
        (_FLAG_BASE_URL, {}),
        (_FLAG_WORKSPACE, {}),
        (_FLAG_MAX_FIX, {"type": int, "default": None}),
        (_FLAG_ASK_ALL, {"action": _ACTION_STORE_TRUE}),
        (_FLAG_YOLO, {"action": _ACTION_STORE_TRUE}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _reject_provider_as_model(args) -> None:
    """Raise _CommandError when ``--model`` names a provider (e.g. "gemini").

    Does nothing when no model was given.
    """
    model = args.model
    if model:
        as_provider = normalize_provider(model)
        if as_provider in known_providers():
            raise _CommandError(*_model_is_provider_problem(model, as_provider))
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _apply_overrides(args, config: Config) -> None:
    """Overlay this run's CLI flags onto ``config`` (in place), then settle the model.

    The provider is normalized whether or not --provider was given. --model
    always wins. Otherwise a model pinned in the config is kept only if the
    provider did not change, and the provider's default model fills any gap.
    """
    previous = normalize_provider(config.provider)
    config.provider = normalize_provider(args.provider or config.provider)

    if args.base_url:
        config.base_url = args.base_url
    if args.workspace:
        config.default_workspace = args.workspace
    if args.max_fix is not None:
        config.max_fix_attempts = args.max_fix

    # A pinned model belongs to the provider it was pinned for.
    keep_pinned = previous == config.provider
    pinned = config.model if keep_pinned else EMPTY
    config.model = args.model or pinned or default_model_for(config.provider)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _resolve_mode(args, config: Config) -> ApprovalMode:
    """Pick the approval mode for this run.

    Precedence, from highest to lowest:
      1. ``--ask-all``. It beats ``--yolo`` so that contradictory flags fail
         safe instead of silently running everything unconfirmed.
      2. ``--yolo``. It beats the config file setting.
      3. ``config.auto_approve_safe`` set to False, which means ask for everything.
      4. Otherwise SMART.
    """
    if args.ask_all:
        return ApprovalMode.ASK_ALL
    if args.yolo:
        return ApprovalMode.YOLO
    if not config.auto_approve_safe:
        return ApprovalMode.ASK_ALL
    return ApprovalMode.SMART
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _validate_config(config: Config, console: Console) -> None:
    """Reject an impossible failure limit and note when base_url will be ignored.

    Raises:
        _CommandError: if ``config.max_fix_attempts`` is less than 1.
    """
    attempts = config.max_fix_attempts
    if attempts < 1:
        raise _CommandError(_TITLE_BAD_MAX_FIX.format(value=attempts))
    if not config.base_url or config.provider == PROVIDER_OPENAI:
        return
    console.note(_BASE_URL_IGNORED.format(provider=config.provider))
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _create_provider_checked(config: Config) -> ModelProvider:
    """Create the configured provider, turning failures into _CommandError.

    Raises:
        _CommandError: if the provider name is unknown (with a did-you-mean
            hint), or if create_provider() raised ProviderError. The second
            case is reported as a missing-key problem. The provider's own
            message is attached as the detail and the original exception is
            chained, so setup failures unrelated to the key are not hidden
            behind the "No API key" title.
    """
    if config.provider not in known_providers():
        raise _CommandError(*_unknown_provider_problem(config.provider))
    try:
        return create_provider(config)
    except ProviderError as exc:
        title, hints, _ = _missing_key_problem(config)
        raise _CommandError(title, hints, str(exc) or None) from exc
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _prepare(args, console: Console):
    """Load the config, apply this run's flags, and build a ready-to-run Agent.

    Returns ``(agent, config, mode)``. The helpers it calls raise _CommandError
    for user-fixable problems.
    """
    config = load_config()
    _reject_provider_as_model(args)
    _apply_overrides(args, config)
    _validate_config(config, console)
    provider = _create_provider_checked(config)
    mode = _resolve_mode(args, config)
    agent = Agent(provider, Shell(), console, config, mode)
    return agent, config, mode
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _dispatch(agent: Agent, goal: str) -> None:
    """Run a one-off goal, or start the REPL when the goal is empty."""
    if not goal:
        _repl(agent, agent.console)
        return
    agent.run_turn(goal)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _repl(agent: Agent, console: Console) -> None:
    """Read goals interactively and run each one until the user leaves.

    The loop ends on EOF or Ctrl-C at the prompt, or when the line is one of
    _EXIT_WORDS (case-insensitive). Input is stripped first, so padded entries
    such as " exit " are recognised and whitespace-only lines are skipped
    instead of being sent to the model.
    """
    console.repl_intro(agent.shell.cwd)
    while True:
        try:
            line = console.repl_prompt(agent.shell.cwd)
        except (EOFError, KeyboardInterrupt):
            print()
            return
        line = line.strip()
        if not line:
            continue
        if line.lower() in _EXIT_WORDS:
            return
        agent.run_turn(line)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _run_agent(agent: Agent, config: Config, console: Console, goal: str) -> int:
    """Run the agent and map how it ended to a process exit code.

    - Ctrl-C or a "quit" answer returns SIGINT_EXIT_CODE.
    - An auth error is reported as a missing key.
    - Any other provider error is reported as a failed request.
    - A clean finish returns OK_EXIT_CODE.
    """
    try:
        _dispatch(agent, goal)
    except (KeyboardInterrupt, QuitRequested):
        console.stopped()
        return SIGINT_EXIT_CODE
    except ProviderAuthError:
        problem = _missing_key_problem(config)
    except ProviderError as exc:
        problem = _provider_request_problem(config, exc)
    else:
        return OK_EXIT_CODE
    return _fail(console, *problem)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def cmd_run(argv: list[str], console: Console) -> int:
    """Run one goal with this run's flags, or the REPL when no goal words are given.

    Flags may appear anywhere among the goal words, for example
    ``prompter make a folder --yolo called scratch``. Plain parse_args() stops
    collecting the ``nargs="*"`` positional at the first flag and rejects the
    words after it as unrecognized arguments, so the intermixed parser is used.

    Returns the process exit code.
    """
    args = _run_parser().parse_intermixed_args(argv)
    try:
        agent, config, mode = _prepare(args, console)
    except _CommandError as err:
        return _fail(console, err.title, err.hints, err.detail)
    console.banner(config, mode)
    goal = SPACE.join(args.prompt).strip()
    return _run_agent(agent, config, console, goal)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _keys_list(console: Console) -> int:
    """Print where each known provider's API key comes from.

    For each provider, a non-empty environment variable is reported ahead of a
    stored key, and a stored key is shown masked.
    """
    for provider in known_providers():
        env = default_api_key_env(provider)
        stored = stored_key(provider)
        if env and os.environ.get(env):
            source, present = _SOURCE_ENV.format(env=env), True
        elif stored:
            source, present = _SOURCE_STORED.format(masked=_mask(stored)), True
        else:
            source, present = _SOURCE_NOT_SET, False
        console.key_status(provider, source, present)
    return OK_EXIT_CODE
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _keys_add(rest: list[str], console: Console) -> int:
    """Store an API key. ``rest`` is ``[provider, key, ...]``; extra words are ignored."""
    if len(rest) < 2:
        return _fail(console, _TITLE_KEYS_ADD_USAGE, [(_LABEL_EXAMPLE, _EXAMPLE_KEYS_ADD)])
    raw_name, key, *_ = rest
    provider = normalize_provider(raw_name)
    if provider not in known_providers():
        return _fail(console, *_unknown_provider_problem(raw_name))
    set_key(provider, key)
    console.success(_KEY_SAVED.format(provider=provider, path=KEYS_PATH))
    return OK_EXIT_CODE
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _keys_remove(rest: list[str], console: Console) -> int:
    """Delete the stored API key for the provider named in ``rest[0]``.

    Unknown provider names now get the same did-you-mean problem that
    ``keys add`` and ``use`` report, instead of falling through to clear_key().
    Removing a key that was never stored is not an error.
    """
    if not rest:
        return _fail(console, _TITLE_KEYS_REMOVE_USAGE)
    provider = normalize_provider(rest[0])
    if provider not in known_providers():
        return _fail(console, *_unknown_provider_problem(rest[0]))
    if clear_key(provider):
        console.success(_KEY_REMOVED.format(provider=provider))
    else:
        console.info(_NO_STORED_KEY.format(provider=provider))
    return OK_EXIT_CODE
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def cmd_keys(argv: list[str], console: Console) -> int:
    """Dispatch ``prompter keys [list | add ... | remove ...]``; with no action, list."""
    if not argv:
        return _keys_list(console)
    action, rest = argv[0], argv[1:]
    if action == _KEYS_LIST:
        return _keys_list(console)
    if action == _KEYS_ADD:
        return _keys_add(rest, console)
    if action in _KEYS_REMOVE:
        return _keys_remove(rest, console)
    return _fail(console, _TITLE_UNKNOWN_KEYS_ACTION.format(action=action),
                 [(_LABEL_USAGE, _USAGE_KEYS)])
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def cmd_use(argv: list[str], console: Console) -> int:
    """Save the default provider, and optionally a model, to the config file.

    Omitting the model stores EMPTY, which clears any previously pinned model.
    Extra words after the model are ignored.
    """
    if not argv:
        return _fail(console, _TITLE_USE_USAGE, [(_LABEL_EXAMPLE, _EXAMPLE_USE)])
    requested, *extra = argv
    provider = normalize_provider(requested)
    if provider not in known_providers():
        return _fail(console, *_unknown_provider_problem(requested))
    config = load_config()
    config.provider = provider
    config.model = extra[0] if extra else EMPTY
    save_config(config)
    console.success(_DEFAULT_SET.format(provider=provider, model=config.resolved_model))
    return OK_EXIT_CODE
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def cmd_status(argv: list[str], console: Console) -> int:
    """Show the saved setup, followed by each provider's key status.

    ``argv`` is accepted so the signature fits the dispatch table, and is ignored.
    """
    config = load_config()
    show = console.field
    show(_FIELD_PROVIDER, config.provider)
    show(_FIELD_MODEL, config.resolved_model)
    show(_FIELD_WORKSPACE, config.workspace_path)
    show(_FIELD_MAX_FIX, str(config.max_fix_attempts))
    show(_FIELD_CONFIG, CONFIG_PATH)
    console.info(_KEYS_HEADER)
    return _keys_list(console)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def cmd_config(argv: list[str], console: Console) -> int:
    """Print the config file path. ``argv`` is ignored.

    load_config() runs first, presumably so the file at CONFIG_PATH exists
    (or is created) before its path is shown. TODO: confirm in config.py.
    """
    load_config()
    path = CONFIG_PATH
    console.info(path)
    return OK_EXIT_CODE
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def cmd_help(console: Console) -> int:
    """Print the usage overview and report success."""
    text = _HELP_TEXT
    console.info(text)
    return OK_EXIT_CODE
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
# Subcommand name -> handler. Each handler takes (remaining argv, console) and
# returns an exit code. `help` and `run` are handled separately in main().
_COMMANDS = dict(
    [
        (_CMD_KEYS, cmd_keys),
        (_CMD_USE, cmd_use),
        (_CMD_STATUS, cmd_status),
        (_CMD_CONFIG, cmd_config),
    ]
)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point. Returns the process exit code.

    ``argv`` defaults to ``sys.argv[1:]``. Dispatch rules:
    - no arguments: cmd_run with no goal;
    - ``help``, ``-h`` or ``--help``: print help;
    - ``run``: run the following words as a goal;
    - a known subcommand name: call its handler;
    - anything else: goal words plus run flags.
    """
    args = list(sys.argv[1:] if argv is None else argv)
    console = Console()
    if not args:
        return cmd_run(args, console)
    head, tail = args[0], args[1:]
    if head == _CMD_HELP or head in _HELP_FLAGS:
        return cmd_help(console)
    if head == _CMD_RUN:
        return cmd_run(tail, console)
    handler = _COMMANDS.get(head)
    if handler is None:
        return cmd_run(args, console)
    return handler(tail, console)
|