auzek 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- auzek/__init__.py +3 -0
- auzek/__main__.py +4 -0
- auzek/cli.py +201 -0
- auzek/config.py +92 -0
- auzek/graph.py +94 -0
- auzek/llm.py +183 -0
- auzek/memory/__init__.py +1 -0
- auzek/memory/plan_store.py +128 -0
- auzek/nodes/__init__.py +19 -0
- auzek/nodes/_util.py +17 -0
- auzek/nodes/approval.py +19 -0
- auzek/nodes/commit.py +28 -0
- auzek/nodes/context.py +24 -0
- auzek/nodes/execution.py +75 -0
- auzek/nodes/planning.py +102 -0
- auzek/nodes/recovery.py +84 -0
- auzek/nodes/report.py +36 -0
- auzek/nodes/verification.py +95 -0
- auzek/prompts.py +99 -0
- auzek/runtime.py +148 -0
- auzek/state.py +64 -0
- auzek/tools/__init__.py +39 -0
- auzek/tools/base.py +121 -0
- auzek/tools/filesystem.py +154 -0
- auzek/tools/git_tools.py +69 -0
- auzek/tools/search.py +75 -0
- auzek/tools/shell.py +59 -0
- auzek-0.1.0.dist-info/METADATA +220 -0
- auzek-0.1.0.dist-info/RECORD +32 -0
- auzek-0.1.0.dist-info/WHEEL +5 -0
- auzek-0.1.0.dist-info/entry_points.txt +2 -0
- auzek-0.1.0.dist-info/top_level.txt +1 -0
auzek/__init__.py
ADDED
auzek/__main__.py
ADDED
auzek/cli.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Command-line entrypoint.
|
|
2
|
+
|
|
3
|
+
agent run "add retry logic to the API client" --provider groq
|
|
4
|
+
agent providers
|
|
5
|
+
agent plan-show
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Optional
|
|
13
|
+
|
|
14
|
+
# Make output robust on Windows consoles (cp1252) so glyphs never crash a run.
# Runs once at import time: both standard streams are switched to UTF-8 and
# unencodable characters are replaced instead of raising.
for _stream in (sys.stdout, sys.stderr):
    try:
        _stream.reconfigure(encoding="utf-8", errors="replace") # type: ignore[attr-defined]
    except (AttributeError, ValueError):
        # AttributeError: the stream has no reconfigure() (e.g. it was replaced
        # by a non-TextIOWrapper object or is None).
        # ValueError: NOTE(review): presumably raised when the stream can no
        # longer be reconfigured — best effort, keep the existing encoding.
        pass
|
|
20
|
+
|
|
21
|
+
import typer
|
|
22
|
+
from dotenv import load_dotenv
|
|
23
|
+
from rich.console import Console
|
|
24
|
+
from rich.markdown import Markdown
|
|
25
|
+
from rich.panel import Panel
|
|
26
|
+
from rich.prompt import Confirm
|
|
27
|
+
from rich.table import Table
|
|
28
|
+
|
|
29
|
+
from .config import AgentConfig
|
|
30
|
+
from .graph import build_graph
|
|
31
|
+
from .llm import LLM, PROVIDERS, LLMConfigError, available_providers
|
|
32
|
+
from .memory.plan_store import Plan, PlanStore
|
|
33
|
+
from .runtime import Deps
|
|
34
|
+
from .state import new_state
|
|
35
|
+
from .tools import build_default_registry
|
|
36
|
+
from .tools.base import ToolContext
|
|
37
|
+
|
|
38
|
+
# Root Typer application; every @app.command() below registers a subcommand.
app = typer.Typer(add_completion=False, help="Autonomous coding agent.")
# Shared Rich console used by all helpers and commands in this module.
console = Console()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# --------------------------------------------------------------------- helpers
|
|
43
|
+
def _emit(msg: str) -> None:
    """Print an agent progress message, coloured by its leading marker.

    ``[phase]`` banners are cyan, success ticks green, failure crosses red,
    and everything else is dimmed.
    """
    if msg.startswith("[phase]"):
        colour = "cyan"
    elif msg.startswith(("✓", " ✓")):
        colour = "green"
    elif msg.startswith(("✗", " ✗")):
        colour = "red"
    else:
        colour = "dim"
    console.print(msg, style=colour)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _build_deps(cfg: AgentConfig, api_key: Optional[str]) -> Deps:
    """Wire up the LLM client, tool registry and plan store for one run.

    Args:
        cfg: Fully resolved agent configuration.
        api_key: Explicit API key from the CLI; when ``None`` the LLM client
            falls back to the provider's key environment variable.

    Returns:
        The ``Deps`` bundle bound into every graph node.

    Raises:
        LLMConfigError: unknown provider or no usable API key.
    """
    # Forward the configured model only when it was explicitly set (CLI flag,
    # AGENT_MODEL or config.yaml). AgentConfig's built-in default is an
    # Anthropic model id, so forwarding it unconditionally would make
    # `--provider groq` request "groq/claude-sonnet-4-6". With None, LLM
    # falls back to the selected provider's own default model.
    model = cfg.model if "model" in cfg.model_fields_set else None
    llm = LLM(
        cfg.provider,
        model,
        temperature=cfg.temperature,
        max_tokens=cfg.max_tokens,
        api_key=api_key,
    )
    tool_ctx = ToolContext(workspace=cfg.workspace, deny_globs=cfg.deny_globs)
    registry = build_default_registry(tool_ctx)
    plan_store = PlanStore(cfg.state_dir)
    return Deps(
        config=cfg,
        llm=llm,
        tools=registry,
        tool_ctx=tool_ctx,
        plan_store=plan_store,
        emit=_emit,
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _render_plan(plan: Plan) -> None:
    """Show the proposed plan as a table, followed by any planner assumptions."""
    steps_table = Table(title="Proposed Plan", show_lines=False, header_style="bold")
    steps_table.add_column("#", justify="right", style="cyan", no_wrap=True)
    steps_table.add_column("Step")
    steps_table.add_column("Files", style="dim")
    for step in plan.steps:
        files_cell = ", ".join(step.files) or "—"
        steps_table.add_row(str(step.id), step.description, files_cell)
    console.print(steps_table)

    if not plan.assumptions:
        return
    bullets = "\n".join(f"• {item}" for item in plan.assumptions)
    console.print(Panel(bullets, title="Assumptions", border_style="yellow"))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# ------------------------------------------------------------------------ run
|
|
86
|
+
@app.command()
def run(
    task: str = typer.Argument(..., help="The task/assignment in natural language."),
    provider: Optional[str] = typer.Option(None, help=f"One of: {', '.join(PROVIDERS)}"),
    model: Optional[str] = typer.Option(None, help="Model id (defaults per provider)."),
    api_key: Optional[str] = typer.Option(None, help="API key (else read from env/.env)."),
    workspace: Path = typer.Option(Path.cwd(), help="Repo to operate on."),
    yes: bool = typer.Option(False, "--yes", "-y", help="Auto-approve the plan."),
    no_approval: bool = typer.Option(False, help="Disable the approval gate entirely."),
    max_steps: Optional[int] = typer.Option(None, help="Cap on plan steps executed."),
    auto_commit: bool = typer.Option(False, help="git-commit after each successful step."),
    temperature: Optional[float] = typer.Option(None, help="Sampling temperature."),
) -> None:
    """Plan and execute a coding task autonomously."""
    # The docstring above is also the `--help` text, so details live here:
    # run context + planning up to the post-planning interrupt, show the plan
    # (asking for approval when interactive), then resume to completion.
    # LLM configuration/authentication errors at any stage exit with code 2.
    from dotenv import find_dotenv

    load_dotenv(workspace / ".env")
    # A bare load_dotenv() searches upward from this module's install location,
    # not the CWD; usecwd=True searches from the CWD as intended. Variables
    # already loaded from the workspace .env are not overridden.
    load_dotenv(find_dotenv(usecwd=True))

    overrides = {
        "provider": provider,
        "model": model,
        "max_steps": max_steps,
        "temperature": temperature,
        "require_plan_approval": False if no_approval else None,
        "auto_commit": True if auto_commit else None,
    }
    cfg = AgentConfig.load(workspace=workspace, overrides=overrides)

    try:
        deps = _build_deps(cfg, api_key)
    except LLMConfigError as exc:
        console.print(f"[bold red]Config error:[/] {exc}")
        raise typer.Exit(code=2)

    console.print(Panel(
        f"[bold]{task}[/]\n\n"
        f"provider=[cyan]{cfg.provider}[/] model=[cyan]{deps.llm.model}[/] "
        f"workspace=[dim]{cfg.workspace}[/]",
        title="Autonomous Agent", border_style="blue",
    ))

    # Prompt only when a human can answer. In every other case (non-tty stdin,
    # --yes, --no-approval, or approval disabled in config) the plan is
    # approved automatically below.
    interactive = sys.stdin.isatty() and not yes and not no_approval and cfg.require_plan_approval
    graph = build_graph(deps, interrupt_for_approval=True)
    thread = {"configurable": {"thread_id": "main"},
              "recursion_limit": max(60, cfg.max_steps * 3 + 30)}

    state = new_state(task, str(cfg.workspace))
    _invoke_graph(graph, state, thread)  # runs context + planning, then interrupts

    snapshot = graph.get_state(thread)
    if "approval" in (snapshot.next or ()):
        plan = Plan.model_validate(snapshot.values["plan"])
        _render_plan(plan)
        approved = True
        if interactive:
            approved = Confirm.ask("Approve this plan and begin execution?", default=True)
        graph.update_state(thread, {"plan_approved": approved})

    # Resume to completion (handles approved, rejected, and failed-planning paths).
    # The resume is guarded too: auth failures surface during execution as well.
    final = _invoke_graph(graph, None, thread)
    _finish(final, deps)


def _invoke_graph(graph, payload, thread: dict) -> dict:
    """Invoke (or, with ``payload=None``, resume) the graph.

    An ``LLMConfigError`` is reported and turned into exit code 2 instead of
    escaping as a traceback.
    """
    try:
        return graph.invoke(payload, thread)
    except LLMConfigError as exc:
        console.print(f"[bold red]LLM error:[/] {exc}")
        raise typer.Exit(code=2)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _finish(final: dict, deps: Deps) -> None:
    """Print the final report, halt reason and token usage, then set the exit code.

    Args:
        final: Final graph state returned by the last ``graph.invoke``.
        deps: Run dependencies (used for the persisted plan path).

    Raises:
        typer.Exit: with code 1 when the run halted or verification failed,
            so scripts and CI can detect an unsuccessful run.
    """
    halted = final.get("phase") == "halted"
    if final.get("report"):
        console.print(Panel(Markdown(final["report"]), title="Final Report",
                            border_style="green"))
    if halted:
        # `or` also covers a halt_reason key that is present but None/empty.
        console.print(Panel(final.get("halt_reason") or "halted",
                            title="Halted", border_style="red"))

    tu = final.get("token_usage") or {}
    console.print(
        f"[dim]tokens: prompt={tu.get('prompt_tokens', 0)} "
        f"completion={tu.get('completion_tokens', 0)} "
        f"total={tu.get('total_tokens', 0)} | "
        f"plan: {deps.plan_store.md_path}[/]"
    )
    verdict = final.get("verification") or {}
    verification_failed = bool(verdict) and not verdict.get("passed", False)
    if halted or verification_failed:
        raise typer.Exit(code=1)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# ------------------------------------------------------------------ providers
|
|
174
|
+
@app.command()
def providers() -> None:
    """List supported LLM providers and whether a key is configured."""
    from dotenv import find_dotenv

    # Search for a .env upward from the current directory. A bare load_dotenv()
    # starts from this module's install location instead, so a project's .env
    # would be missed when the package is installed into site-packages.
    load_dotenv(find_dotenv(usecwd=True))
    avail = available_providers()
    table = Table(title="LLM Providers", header_style="bold")
    table.add_column("Provider", style="cyan")
    table.add_column("Default model")
    table.add_column("Key env var")
    table.add_column("Ready", justify="center")
    for name, spec in PROVIDERS.items():
        ready = "[green]ready[/]" if avail[name] else "[dim]-[/]"
        table.add_row(name, spec.default_model, spec.key_env or "(none)", ready)
    console.print(table)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@app.command("plan-show")
def plan_show(workspace: Path = typer.Option(Path.cwd())) -> None:
    """Print the persisted plan for a workspace, if any."""
    # The plan lives in <workspace>/.agent (the same directory name
    # AgentConfig.state_dir uses); exit 1 when nothing has been saved yet.
    plan_dir = workspace / ".agent"
    store = PlanStore(plan_dir)
    if store.exists():
        console.print(Markdown(store.load().to_markdown()))
        return
    console.print("[yellow]No plan found in .agent/[/]")
    raise typer.Exit(code=1)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# Allow running this module directly, e.g. `python -m auzek.cli`.
if __name__ == "__main__":
    app()
|
auzek/config.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Runtime configuration for the agent.
|
|
2
|
+
|
|
3
|
+
Resolution order (highest priority first):
|
|
4
|
+
1. Explicit CLI flags
|
|
5
|
+
2. Environment variables (AGENT_*)
|
|
6
|
+
3. config.yaml in the workspace directory (the current directory by default)
|
|
7
|
+
4. Hard-coded defaults below
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import yaml
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
# Default paths the agent's tools should refuse to touch: VCS internals,
# .env secrets, dependency/virtualenv trees, bytecode caches and build output.
# Copied into each AgentConfig.deny_globs and handed to the tool layer.
# NOTE(review): exact matching semantics depend on the glob matcher in the
# tools package (not visible here) — e.g. whether ".git/**" also covers nested
# .git dirs; "**/.env" does not cover ".env.local"-style files. Confirm.
DEFAULT_DENY_GLOBS = [
    ".git/**",
    "**/.env",
    "**/node_modules/**",
    "**/.venv/**",
    "**/__pycache__/**",
    "**/dist/**",
    "**/build/**",
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AgentConfig(BaseModel):
    """All knobs that control a single agent run."""

    # --- Model selection ---
    provider: str = "anthropic"
    # NOTE(review): this default names an Anthropic model even when another
    # provider is selected; consumers should forward it only when the field
    # was set explicitly (``model_fields_set``) — TODO confirm all callers do.
    model: str = "claude-sonnet-4-6"
    temperature: float = 0.0
    max_tokens: int = 8192

    # --- Behaviour ---
    max_recovery_attempts: int = 3
    max_steps: int = 40
    auto_commit: bool = False
    require_plan_approval: bool = True

    # --- Verification commands (auto-detected when blank) ---
    test_command: str = ""
    lint_command: str = ""
    typecheck_command: str = ""

    # --- Safety ---
    deny_globs: list[str] = Field(default_factory=lambda: list(DEFAULT_DENY_GLOBS))

    # --- Paths (filled in at load time) ---
    workspace: Path = Field(default_factory=Path.cwd)

    model_config = {"arbitrary_types_allowed": True}

    @classmethod
    def load(
        cls,
        workspace: Path | None = None,
        overrides: dict[str, Any] | None = None,
    ) -> "AgentConfig":
        """Build a config from yaml + env + overrides.

        Sources are applied in increasing priority: ``<workspace>/config.yaml``,
        then ``AGENT_<FIELD>`` environment variables, then every non-``None``
        entry of *overrides*. Blank yaml/env values are ignored. ``workspace``
        always comes from the argument (default: the CWD), resolved to an
        absolute path.

        List-typed fields (e.g. ``deny_globs``) read from the environment are
        parsed as comma-separated values, e.g.
        ``AGENT_DENY_GLOBS="**/.env,secrets/**"``.

        Raises:
            ValueError: ``config.yaml`` is not a mapping, or a value fails
                validation (pydantic's ValidationError subclasses ValueError).
        """
        from typing import get_origin

        workspace = (workspace or Path.cwd()).resolve()
        data: dict[str, Any] = {}

        # 1. config.yaml
        cfg_path = workspace / "config.yaml"
        if cfg_path.exists():
            loaded = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
            if not isinstance(loaded, dict):
                # e.g. a yaml list or scalar: fail clearly instead of
                # crashing on `.items()` below.
                raise ValueError(
                    f"{cfg_path} must contain a mapping of settings, "
                    f"got {type(loaded).__name__}"
                )
            data.update({k: v for k, v in loaded.items() if v not in (None, "")})

        # 2. environment variables (AGENT_PROVIDER, AGENT_MODEL, ...)
        for name, info in cls.model_fields.items():
            raw = os.environ.get(f"AGENT_{name.upper()}", "")
            if raw == "":
                continue
            if get_origin(info.annotation) is list:
                # A bare string can never validate as list[str].
                data[name] = [item.strip() for item in raw.split(",") if item.strip()]
            else:
                data[name] = raw

        # 3. explicit overrides (CLI)
        if overrides:
            data.update({k: v for k, v in overrides.items() if v is not None})

        data["workspace"] = workspace
        return cls(**data)

    @property
    def state_dir(self) -> Path:
        """Where the agent persists its plan and run state.

        Side effect: creates ``<workspace>/.agent`` if it is missing (the
        workspace directory itself must already exist).
        """
        d = self.workspace / ".agent"
        d.mkdir(exist_ok=True)
        return d
|
auzek/graph.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Assemble the agent as a LangGraph state machine.
|
|
2
|
+
|
|
3
|
+
Flow:
|
|
4
|
+
context → planning → [approval gate] → execution ⇄ recovery → verification → report
|
|
5
|
+
|
|
6
|
+
Conditional edges route on ``state['phase']`` so the same execution node can loop
|
|
7
|
+
over many steps, and recovery can loop until it succeeds or hits its attempt cap.
|
|
8
|
+
A checkpointer + interrupt after planning lets the CLI insert human approval.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from functools import partial
|
|
14
|
+
|
|
15
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
16
|
+
from langgraph.graph import END, START, StateGraph
|
|
17
|
+
|
|
18
|
+
from .nodes import (
|
|
19
|
+
context_node,
|
|
20
|
+
execution_node,
|
|
21
|
+
planning_node,
|
|
22
|
+
recovery_node,
|
|
23
|
+
report_node,
|
|
24
|
+
verification_node,
|
|
25
|
+
)
|
|
26
|
+
from .nodes.approval import approval_node
|
|
27
|
+
from .runtime import Deps
|
|
28
|
+
from .state import AgentState
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# --------------------------------------------------------------------- routers
|
|
32
|
+
def _after_planning(state: AgentState) -> str:
|
|
33
|
+
return "report_node" if state.get("phase") == "halted" else "approval"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _after_approval(state: AgentState) -> str:
|
|
37
|
+
return "execution" if state.get("phase") == "execution" else "report_node"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _after_execution(state: AgentState) -> str:
|
|
41
|
+
phase = state.get("phase")
|
|
42
|
+
if phase == "recovery":
|
|
43
|
+
return "recovery"
|
|
44
|
+
if phase == "verification":
|
|
45
|
+
return "verify"
|
|
46
|
+
if phase == "halted":
|
|
47
|
+
return "report_node"
|
|
48
|
+
return "execution" # more steps remain
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _after_recovery(state: AgentState) -> str:
|
|
52
|
+
phase = state.get("phase")
|
|
53
|
+
if phase == "execution":
|
|
54
|
+
return "execution"
|
|
55
|
+
if phase == "verification":
|
|
56
|
+
return "verify"
|
|
57
|
+
if phase == "halted":
|
|
58
|
+
return "report_node"
|
|
59
|
+
return "recovery" # keep trying (attempt cap enforced inside the node)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def build_graph(deps: Deps, *, interrupt_for_approval: bool = True):
    """Compile the agent graph bound to a set of dependencies.

    Args:
        deps: Runtime dependencies, bound into every node via ``partial``.
        interrupt_for_approval: When true, the compiled graph pauses after the
            "planning" node so the caller can inspect and approve the plan.

    Returns:
        The compiled graph, backed by an in-memory ``MemorySaver`` checkpointer.
    """
    graph = StateGraph(AgentState)

    node_table = (
        ("context", context_node),
        ("planning", planning_node),
        ("approval", approval_node),
        ("execution", execution_node),
        ("recovery", recovery_node),
        ("verify", verification_node),
        ("report_node", report_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, partial(node_fn, deps=deps))

    graph.add_edge(START, "context")
    graph.add_edge("context", "planning")
    graph.add_conditional_edges(
        "planning", _after_planning,
        {"approval": "approval", "report_node": "report_node"},
    )
    graph.add_conditional_edges(
        "approval", _after_approval,
        {"execution": "execution", "report_node": "report_node"},
    )
    # Execution and recovery can hand off to the same four successors.
    for source, router in (("execution", _after_execution), ("recovery", _after_recovery)):
        graph.add_conditional_edges(
            source, router,
            {"execution": "execution", "recovery": "recovery",
             "verify": "verify", "report_node": "report_node"},
        )
    graph.add_edge("verify", "report_node")
    graph.add_edge("report_node", END)

    options = {"checkpointer": MemorySaver()}
    if interrupt_for_approval:
        # Pause after planning so a human can approve the plan before any edits.
        options["interrupt_after"] = ["planning"]
    return graph.compile(**options)
|
auzek/llm.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""Multi-provider LLM gateway.
|
|
2
|
+
|
|
3
|
+
Built on LiteLLM so a single code path talks to Anthropic, OpenAI, Groq,
|
|
4
|
+
Google, Mistral, DeepSeek and local Ollama models. Users supply their own
|
|
5
|
+
API keys via environment variables (.env) or the CLI.
|
|
6
|
+
|
|
7
|
+
The rest of the agent only ever sees `LLM.chat(...)` and a normalized
|
|
8
|
+
`LLMResponse`, so swapping providers never touches node logic.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import litellm
|
|
18
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
19
|
+
|
|
20
|
+
# Don't let LiteLLM phone home or spam logs.
# These assignments set module-level globals on the litellm package at import
# time, so they apply process-wide to every litellm caller, not just this client.
# NOTE(review): nothing here changes log verbosity — confirm "spam logs" intent.
litellm.telemetry = False
litellm.drop_params = True # silently drop params a given provider doesn't support
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class ProviderSpec:
    """Immutable routing details for one LLM provider.

    Attributes:
        prefix: LiteLLM model prefix, e.g. ``"groq"`` (yields ``"groq/<model>"``).
        key_env: Name of the env var holding the API key, e.g. ``"GROQ_API_KEY"``;
            empty for providers that need no key.
        default_model: Model id used when the caller does not choose one.
        needs_key: Whether an API key is mandatory for this provider.
    """

    prefix: str
    key_env: str
    default_model: str
    needs_key: bool = True
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Friendly provider name -> how to call it. Add new providers here only.
# Each entry is ProviderSpec(prefix, key_env, default_model[, needs_key]).
# Note the "google" name maps to LiteLLM's "gemini" prefix and GEMINI_API_KEY,
# and "ollama" needs no key (local server).
# NOTE(review): default model ids age quickly (e.g. "gemini-1.5-pro") —
# periodically verify each is still served by its provider.
PROVIDERS: dict[str, ProviderSpec] = {
    "anthropic": ProviderSpec("anthropic", "ANTHROPIC_API_KEY", "claude-sonnet-4-6"),
    "openai": ProviderSpec("openai", "OPENAI_API_KEY", "gpt-4o"),
    "groq": ProviderSpec("groq", "GROQ_API_KEY", "llama-3.3-70b-versatile"),
    "google": ProviderSpec("gemini", "GEMINI_API_KEY", "gemini-1.5-pro"),
    "mistral": ProviderSpec("mistral", "MISTRAL_API_KEY", "mistral-large-latest"),
    "deepseek": ProviderSpec("deepseek", "DEEPSEEK_API_KEY", "deepseek-chat"),
    "ollama": ProviderSpec("ollama", "", "qwen2.5-coder:7b", needs_key=False),
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LLMConfigError(RuntimeError):
    """Raised when a provider/model/key combination is unusable.

    Raised in this module for an unknown provider name, a missing API key,
    and an authentication failure reported by the provider during a chat call.
    """
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class LLMResponse:
    """Provider-independent view of one chat completion.

    ``tool_calls`` holds dicts with ``id``, ``name`` and ``arguments`` (a JSON
    string) keys; ``usage`` holds token counts; ``raw`` keeps the untouched
    LiteLLM response object for debugging.
    """

    content: str
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    raw: Any = None

    @property
    def wants_tools(self) -> bool:
        """Whether the model requested at least one tool call."""
        return True if self.tool_calls else False
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class LLM:
    """A thin, provider-agnostic chat client with tool-calling support."""

    def __init__(
        self,
        provider: str,
        model: str | None = None,
        *,
        temperature: float = 0.0,
        max_tokens: int = 8192,
        api_key: str | None = None,
    ) -> None:
        """Resolve the provider spec, model id and API key.

        Args:
            provider: Key of ``PROVIDERS``; case and surrounding whitespace
                are ignored.
            model: Model id; falls back to the provider's default when empty.
            temperature: Default sampling temperature for ``chat``.
            max_tokens: Completion token cap sent with every request.
            api_key: Explicit key; otherwise read from the provider's env var.

        Raises:
            LLMConfigError: unknown provider, or a required API key is missing.
        """
        provider = provider.lower().strip()
        if provider not in PROVIDERS:
            raise LLMConfigError(
                f"Unknown provider '{provider}'. "
                f"Choose one of: {', '.join(PROVIDERS)}"
            )
        self.provider = provider
        self.spec = PROVIDERS[provider]
        self.model = model or self.spec.default_model
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Explicit key wins; an empty string from either source counts as "no key".
        self.api_key = api_key or os.environ.get(self.spec.key_env) or None

        if self.spec.needs_key and not self.api_key:
            raise LLMConfigError(
                f"No API key for provider '{provider}'. "
                f"Set {self.spec.key_env} in your environment / .env, "
                f"or pass --api-key."
            )

    @property
    def model_string(self) -> str:
        """LiteLLM-style 'prefix/model'. Avoids double-prefixing."""
        if self.model.startswith(f"{self.spec.prefix}/"):
            return self.model
        return f"{self.spec.prefix}/{self.model}"

    def _extra_kwargs(self) -> dict[str, Any]:
        """Per-request connection kwargs: the API key and, for Ollama, the server URL."""
        kwargs: dict[str, Any] = {}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.provider == "ollama":
            kwargs["api_base"] = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
        return kwargs

    @retry(
        retry=retry_if_exception_type(
            (litellm.RateLimitError, litellm.APIConnectionError, litellm.Timeout)
        ),
        wait=wait_exponential(multiplier=2, min=2, max=30),
        stop=stop_after_attempt(4),
        reraise=True,
    )
    def chat(
        self,
        messages: list[dict[str, Any]],
        *,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Send a chat completion request and normalize the response.

        Rate limits, connection errors and timeouts are retried (up to 4
        attempts, exponential back-off); the last error is re-raised.

        Args:
            messages: OpenAI-style chat messages.
            tools: Tool schemas; when given, ``tool_choice`` defaults to "auto".
            tool_choice: Tool-choice mode, only sent when ``tools`` is non-empty.
            temperature: Per-call override of the client's default temperature.

        Raises:
            LLMConfigError: the provider rejected the credentials.
        """
        call_kwargs: dict[str, Any] = {
            "model": self.model_string,
            "messages": messages,
            "temperature": self.temperature if temperature is None else temperature,
            "max_tokens": self.max_tokens,
            **self._extra_kwargs(),
        }
        if tools:
            call_kwargs["tools"] = tools
            call_kwargs["tool_choice"] = tool_choice or "auto"

        try:
            resp = litellm.completion(**call_kwargs)
        except litellm.AuthenticationError as exc:  # bad/expired key
            raise LLMConfigError(
                f"Authentication failed for '{self.provider}'. Check {self.spec.key_env}."
            ) from exc

        choice = resp.choices[0]
        msg = choice.message

        tool_calls: list[dict[str, Any]] = []
        for tc in getattr(msg, "tool_calls", None) or []:
            tool_calls.append(
                {
                    "id": tc.id,
                    "name": tc.function.name,
                    # JSON string. A zero-argument call may arrive as "" or
                    # None, which is not valid JSON; normalise to an empty
                    # object so callers can always json.loads() it.
                    "arguments": tc.function.arguments or "{}",
                }
            )

        usage: dict[str, int] = {}
        raw_usage = getattr(resp, "usage", None)
        if raw_usage:
            # getattr: not every provider reports every counter.
            usage = {
                key: getattr(raw_usage, key, None) or 0
                for key in ("prompt_tokens", "completion_tokens", "total_tokens")
            }

        return LLMResponse(
            content=msg.content or "",
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason or "stop",
            usage=usage,
            raw=resp,
        )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def available_providers() -> dict[str, bool]:
    """Report, per provider name, whether it can be used right now.

    Providers that need no key are always ready; the others are ready when
    their key environment variable is set to a non-empty value.
    """
    return {
        name: (not spec.needs_key) or bool(os.environ.get(spec.key_env))
        for name, spec in PROVIDERS.items()
    }
|
auzek/memory/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Durable agent memory: the plan and run state that survive crashes."""
|