tau-coding-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tau/__init__.py +0 -0
- tau/agent/__init__.py +11 -0
- tau/agent/prompt/__init__.py +10 -0
- tau/agent/prompt/builder.py +302 -0
- tau/agent/prompt/types.py +33 -0
- tau/agent/service.py +369 -0
- tau/agent/types.py +61 -0
- tau/auth/manager.py +247 -0
- tau/auth/storage.py +82 -0
- tau/auth/types.py +41 -0
- tau/builtins/__init__.py +4 -0
- tau/builtins/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/__pycache__/__init__.cpython-314.pyc +0 -0
- tau/builtins/commands/__init__.py +41 -0
- tau/builtins/commands/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/commands/__pycache__/__init__.cpython-314.pyc +0 -0
- tau/builtins/commands/__pycache__/clear.cpython-313.pyc +0 -0
- tau/builtins/commands/__pycache__/clear.cpython-314.pyc +0 -0
- tau/builtins/commands/__pycache__/compact.cpython-313.pyc +0 -0
- tau/builtins/commands/__pycache__/compact.cpython-314.pyc +0 -0
- tau/builtins/commands/__pycache__/reload.cpython-313.pyc +0 -0
- tau/builtins/commands/__pycache__/reload.cpython-314.pyc +0 -0
- tau/builtins/commands/__pycache__/session.cpython-313.pyc +0 -0
- tau/builtins/commands/__pycache__/session.cpython-314.pyc +0 -0
- tau/builtins/commands/clear.py +16 -0
- tau/builtins/commands/compact.py +28 -0
- tau/builtins/commands/reload.py +27 -0
- tau/builtins/commands/session.py +19 -0
- tau/builtins/extensions/footer/__init__.py +76 -0
- tau/builtins/extensions/footer/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/extensions/footer/__pycache__/git.cpython-313.pyc +0 -0
- tau/builtins/extensions/footer/__pycache__/model.cpython-313.pyc +0 -0
- tau/builtins/extensions/footer/__pycache__/utils.cpython-313.pyc +0 -0
- tau/builtins/extensions/footer/git.py +26 -0
- tau/builtins/extensions/footer/model.py +69 -0
- tau/builtins/extensions/footer/utils.py +44 -0
- tau/builtins/extensions/header/__init__.py +18 -0
- tau/builtins/extensions/header/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/models/__init__.py +0 -0
- tau/builtins/models/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/models/__pycache__/text.cpython-313.pyc +0 -0
- tau/builtins/models/audio.py +43 -0
- tau/builtins/models/image.py +43 -0
- tau/builtins/models/text.py +482 -0
- tau/builtins/models/video.py +40 -0
- tau/builtins/prompts/commit.md +7 -0
- tau/builtins/prompts/docs.md +7 -0
- tau/builtins/prompts/explain.md +7 -0
- tau/builtins/prompts/fix.md +7 -0
- tau/builtins/prompts/refactor.md +7 -0
- tau/builtins/prompts/review.md +7 -0
- tau/builtins/prompts/test.md +7 -0
- tau/builtins/providers/__init__.py +0 -0
- tau/builtins/providers/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/providers/__pycache__/text.cpython-313.pyc +0 -0
- tau/builtins/providers/audio.py +10 -0
- tau/builtins/providers/image.py +9 -0
- tau/builtins/providers/text.py +33 -0
- tau/builtins/providers/video.py +6 -0
- tau/builtins/skills/code-review/SKILL.md +4 -0
- tau/builtins/skills/debug/SKILL.md +4 -0
- tau/builtins/skills/git-commit/SKILL.md +4 -0
- tau/builtins/themes/dark.yaml +1 -0
- tau/builtins/themes/light.yaml +46 -0
- tau/builtins/tools/__init__.py +73 -0
- tau/builtins/tools/__pycache__/__init__.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/__init__.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/bash.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/bash.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/edit.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/edit.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/glob.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/glob.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/grep.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/grep.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/ls.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/ls.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/read.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/read.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/terminal.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/terminal.cpython-314.pyc +0 -0
- tau/builtins/tools/__pycache__/write.cpython-313.pyc +0 -0
- tau/builtins/tools/__pycache__/write.cpython-314.pyc +0 -0
- tau/builtins/tools/edit.py +215 -0
- tau/builtins/tools/glob.py +112 -0
- tau/builtins/tools/grep.py +146 -0
- tau/builtins/tools/ls.py +135 -0
- tau/builtins/tools/read.py +122 -0
- tau/builtins/tools/terminal.py +150 -0
- tau/builtins/tools/write.py +105 -0
- tau/commands/__init__.py +10 -0
- tau/commands/registry.py +71 -0
- tau/commands/types.py +33 -0
- tau/console/__init__.py +0 -0
- tau/console/cli.py +266 -0
- tau/console/commands/__init__.py +0 -0
- tau/console/commands/auth.py +193 -0
- tau/console/commands/packages.py +104 -0
- tau/console/commands/update.py +76 -0
- tau/core/__init__.py +0 -0
- tau/core/registry.py +102 -0
- tau/engine/__init__.py +47 -0
- tau/engine/service.py +768 -0
- tau/engine/types.py +163 -0
- tau/extensions/__init__.py +28 -0
- tau/extensions/api.py +928 -0
- tau/extensions/context.py +462 -0
- tau/extensions/events.py +70 -0
- tau/extensions/loader.py +386 -0
- tau/extensions/runtime.py +184 -0
- tau/extensions/settings.py +137 -0
- tau/hooks/__init__.py +112 -0
- tau/hooks/engine.py +237 -0
- tau/hooks/inference.py +21 -0
- tau/hooks/runtime.py +126 -0
- tau/hooks/service.py +121 -0
- tau/hooks/session.py +117 -0
- tau/hooks/tui.py +61 -0
- tau/hooks/types.py +72 -0
- tau/inference/__init__.py +180 -0
- tau/inference/api/__init__.py +0 -0
- tau/inference/api/audio/__init__.py +0 -0
- tau/inference/api/audio/base.py +29 -0
- tau/inference/api/audio/builtins.py +15 -0
- tau/inference/api/audio/elevenlabs_audio.py +183 -0
- tau/inference/api/audio/gemini_audio.py +95 -0
- tau/inference/api/audio/openai_audio.py +159 -0
- tau/inference/api/audio/registry.py +15 -0
- tau/inference/api/audio/sarvam_audio.py +163 -0
- tau/inference/api/audio/service.py +103 -0
- tau/inference/api/audio/utils.py +47 -0
- tau/inference/api/image/__init__.py +0 -0
- tau/inference/api/image/base.py +17 -0
- tau/inference/api/image/builtins.py +8 -0
- tau/inference/api/image/gemini_image.py +77 -0
- tau/inference/api/image/openai_image.py +103 -0
- tau/inference/api/image/openrouter.py +144 -0
- tau/inference/api/image/registry.py +15 -0
- tau/inference/api/image/service.py +71 -0
- tau/inference/api/registry.py +82 -0
- tau/inference/api/text/__init__.py +0 -0
- tau/inference/api/text/anthropic_claude_code.py +222 -0
- tau/inference/api/text/anthropic_messages.py +196 -0
- tau/inference/api/text/base.py +40 -0
- tau/inference/api/text/builtins.py +19 -0
- tau/inference/api/text/gemini_generate.py +234 -0
- tau/inference/api/text/github_copilot_chat.py +172 -0
- tau/inference/api/text/google_antigravity.py +522 -0
- tau/inference/api/text/mistral_chat.py +284 -0
- tau/inference/api/text/ollama_chat.py +200 -0
- tau/inference/api/text/openai_codex_responses.py +497 -0
- tau/inference/api/text/openai_completions.py +227 -0
- tau/inference/api/text/openai_responses.py +235 -0
- tau/inference/api/text/registry.py +50 -0
- tau/inference/api/text/service.py +297 -0
- tau/inference/api/text/types.py +7 -0
- tau/inference/api/text/utils.py +228 -0
- tau/inference/api/video/__init__.py +0 -0
- tau/inference/api/video/base.py +26 -0
- tau/inference/api/video/builtins.py +7 -0
- tau/inference/api/video/fal_video.py +119 -0
- tau/inference/api/video/openrouter_video.py +142 -0
- tau/inference/api/video/registry.py +15 -0
- tau/inference/api/video/service.py +72 -0
- tau/inference/model/__init__.py +0 -0
- tau/inference/model/registry.py +102 -0
- tau/inference/model/types.py +65 -0
- tau/inference/provider/__init__.py +0 -0
- tau/inference/provider/oauth/__init__.py +35 -0
- tau/inference/provider/oauth/anthropic_claude_code.py +286 -0
- tau/inference/provider/oauth/github_copilot.py +333 -0
- tau/inference/provider/oauth/google_antigravity.py +258 -0
- tau/inference/provider/oauth/openai_codex.py +309 -0
- tau/inference/provider/oauth/pkce.py +14 -0
- tau/inference/provider/oauth/types.py +46 -0
- tau/inference/provider/oauth/utils.py +154 -0
- tau/inference/provider/registry.py +141 -0
- tau/inference/provider/types.py +114 -0
- tau/inference/types.py +549 -0
- tau/inference/utils.py +219 -0
- tau/message/__init__.py +0 -0
- tau/message/types.py +482 -0
- tau/message/utils.py +178 -0
- tau/packages/__init__.py +11 -0
- tau/packages/manager.py +190 -0
- tau/packages/types.py +20 -0
- tau/packages/utils.py +67 -0
- tau/prompts/expand.py +58 -0
- tau/prompts/loader.py +69 -0
- tau/prompts/registry.py +45 -0
- tau/prompts/types.py +24 -0
- tau/rpc/__init__.py +8 -0
- tau/rpc/mode.py +783 -0
- tau/rpc/types.py +252 -0
- tau/runtime/service.py +759 -0
- tau/runtime/types.py +303 -0
- tau/session/branch_summarization.py +312 -0
- tau/session/compaction.py +646 -0
- tau/session/manager.py +652 -0
- tau/session/types.py +188 -0
- tau/session/utils.py +233 -0
- tau/settings/manager.py +1077 -0
- tau/settings/paths.py +150 -0
- tau/settings/storage.py +63 -0
- tau/settings/types.py +173 -0
- tau/settings/utils.py +25 -0
- tau/skills/loader.py +91 -0
- tau/skills/registry.py +70 -0
- tau/skills/types.py +25 -0
- tau/themes/loader.py +238 -0
- tau/themes/registry.py +108 -0
- tau/themes/types.py +19 -0
- tau/tool/__init__.py +3 -0
- tau/tool/registry.py +117 -0
- tau/tool/render.py +21 -0
- tau/tool/types.py +244 -0
- tau/trust/__init__.py +13 -0
- tau/trust/manager.py +80 -0
- tau/trust/types.py +14 -0
- tau/trust/utils.py +72 -0
- tau/tui/__init__.py +54 -0
- tau/tui/agent_hooks.py +346 -0
- tau/tui/ansi.py +330 -0
- tau/tui/app.py +540 -0
- tau/tui/autocomplete.py +33 -0
- tau/tui/capabilities.py +119 -0
- tau/tui/commands/__init__.py +3 -0
- tau/tui/commands/appearance.py +498 -0
- tau/tui/commands/auth.py +232 -0
- tau/tui/commands/context.py +38 -0
- tau/tui/commands/misc.py +82 -0
- tau/tui/commands/model.py +118 -0
- tau/tui/commands/session.py +464 -0
- tau/tui/component.py +268 -0
- tau/tui/components/__init__.py +0 -0
- tau/tui/components/autocomplete_manager.py +267 -0
- tau/tui/components/autocomplete_picker.py +143 -0
- tau/tui/components/box.py +90 -0
- tau/tui/components/command_palette.py +144 -0
- tau/tui/components/dynamic_border.py +19 -0
- tau/tui/components/file_picker.py +233 -0
- tau/tui/components/image.py +181 -0
- tau/tui/components/inline_selector.py +71 -0
- tau/tui/components/layout.py +1194 -0
- tau/tui/components/message_list.py +692 -0
- tau/tui/components/modal.py +97 -0
- tau/tui/components/model_palette.py +204 -0
- tau/tui/components/picker_overlay.py +174 -0
- tau/tui/components/prompt_overlay.py +236 -0
- tau/tui/components/resume_modal.py +372 -0
- tau/tui/components/select_list.py +222 -0
- tau/tui/components/settings_modal.py +274 -0
- tau/tui/components/settings_schema.py +203 -0
- tau/tui/components/spinner.py +119 -0
- tau/tui/components/text_input.py +396 -0
- tau/tui/components/text_prompt.py +82 -0
- tau/tui/components/tree_select_list.py +580 -0
- tau/tui/components/trust_screen.py +97 -0
- tau/tui/diff.py +114 -0
- tau/tui/fuzzy.py +99 -0
- tau/tui/input.py +496 -0
- tau/tui/input_handler.py +716 -0
- tau/tui/keybindings.py +87 -0
- tau/tui/markdown.py +286 -0
- tau/tui/message_renderers.py +31 -0
- tau/tui/overlay.py +326 -0
- tau/tui/renderer.py +378 -0
- tau/tui/terminal.py +499 -0
- tau/tui/theme.py +148 -0
- tau/tui/tui.py +544 -0
- tau/tui/ui_context.py +768 -0
- tau/tui/utils.py +20 -0
- tau/utils/__init__.py +0 -0
- tau/utils/http_proxy.py +221 -0
- tau/utils/image_processing.py +172 -0
- tau/utils/secrets.py +59 -0
- tau/utils/version_check.py +60 -0
- tau_coding_agent-0.1.0.dist-info/METADATA +177 -0
- tau_coding_agent-0.1.0.dist-info/RECORD +283 -0
- tau_coding_agent-0.1.0.dist-info/WHEEL +5 -0
- tau_coding_agent-0.1.0.dist-info/entry_points.txt +2 -0
- tau_coding_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- tau_coding_agent-0.1.0.dist-info/top_level.txt +1 -0
tau/agent/service.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from tau.agent.types import AgentConfig, AgentContext, AgentPhase, PromptOptions, ContextUsage
|
|
8
|
+
from tau.hooks.service import Hooks
|
|
9
|
+
from tau.hooks.engine import MessageEndEvent, MessageRollbackEvent, SavePointEvent, SettledEvent
|
|
10
|
+
from tau.message.types import AgentMessage, AssistantMessage, TerminalExecutionMessage, LLMMessage, UserMessage, TextContent, ToolMessage
|
|
11
|
+
from tau.message.utils import strip_unusable_trailing_assistant
|
|
12
|
+
from tau.tool.types import ToolInvocation, ToolResult
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from tau.engine.service import Engine
|
|
16
|
+
from tau.session.manager import SessionManager
|
|
17
|
+
from tau.runtime.service import Runtime
|
|
18
|
+
from tau.session.compaction import CompactionPreparation
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _to_llm_messages(messages: list[AgentMessage]) -> list[LLMMessage]:
    """Translate stored agent messages into the subset an LLM provider accepts.

    Conversion rules, applied per message in order:

    * ``CompactionSummaryMessage`` -> ``UserMessage`` whose text wraps the
      summary in ``<context-summary>`` tags.
    * ``TerminalExecutionMessage`` -> ``UserMessage`` via ``to_user_message()``,
      unless the message's ``exclude`` flag is set (then it is dropped).
    * ``AssistantMessage`` -> kept only if it has at least one text, tool-call,
      or thinking content item. Content-less assistant turns are visual-only
      markers (aborts, persisted API/credit errors); sending them back is
      invalid and triggers provider 400s.
    * ``UserMessage`` / ``ToolMessage`` -> passed through unchanged.
    * ``CustomMessage`` and any other non-LLM type -> omitted.
    """
    from tau.message.types import CompactionSummaryMessage, ThinkingContent, ToolCallContent

    # Content kinds that make an assistant turn worth replaying to the model.
    replayable_parts = (TextContent, ToolCallContent, ThinkingContent)
    converted: list[LLMMessage] = []
    for item in messages:
        if isinstance(item, CompactionSummaryMessage):
            converted.append(UserMessage.from_text(f"<context-summary>\n{item.summary}\n</context-summary>"))
        elif isinstance(item, TerminalExecutionMessage):
            if item.exclude:
                continue
            converted.append(item.to_user_message())
        elif isinstance(item, AssistantMessage):
            if any(isinstance(part, replayable_parts) for part in item.contents):
                converted.append(item)
        elif isinstance(item, (UserMessage, ToolMessage)):
            converted.append(item)
    return converted
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Agent:
|
|
54
|
+
"""
|
|
55
|
+
High-level agent session tying together Engine and SessionManager.
|
|
56
|
+
|
|
57
|
+
Call `invoke()` to run a user turn. The session persists each message
|
|
58
|
+
and tracks token usage.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
def __init__(
|
|
62
|
+
self,
|
|
63
|
+
engine: Engine,
|
|
64
|
+
session_manager: SessionManager,
|
|
65
|
+
config: AgentConfig,
|
|
66
|
+
hooks: Hooks | None = None,
|
|
67
|
+
) -> None:
|
|
68
|
+
self._engine = engine
|
|
69
|
+
self._session_manager = session_manager
|
|
70
|
+
self._config = config
|
|
71
|
+
self._system_prompt: str = config.system_prompt
|
|
72
|
+
self._context_tokens: int = 0
|
|
73
|
+
self._context_window: int = config.context_window
|
|
74
|
+
self._runtime: Runtime | None = None
|
|
75
|
+
self.hooks = hooks or Hooks()
|
|
76
|
+
|
|
77
|
+
self._phase: AgentPhase = AgentPhase.IDLE
|
|
78
|
+
self._signal: asyncio.Event = asyncio.Event()
|
|
79
|
+
self._compaction_failures: int = 0
|
|
80
|
+
self._engine.options.before_tool_call = self._before_tool_call
|
|
81
|
+
self._engine.options.after_tool_call = self._after_tool_call
|
|
82
|
+
|
|
83
|
+
# -------------------------------------------------------------------------
|
|
84
|
+
# Public interface
|
|
85
|
+
# -------------------------------------------------------------------------
|
|
86
|
+
|
|
87
|
+
@property
|
|
88
|
+
def cwd(self) -> Path:
|
|
89
|
+
"""Get the current working directory."""
|
|
90
|
+
return self._config.cwd
|
|
91
|
+
|
|
92
|
+
@property
|
|
93
|
+
def session_manager(self) -> SessionManager:
|
|
94
|
+
"""Get the session manager instance."""
|
|
95
|
+
return self._session_manager
|
|
96
|
+
|
|
97
|
+
def is_idle(self) -> bool:
|
|
98
|
+
"""Check if the agent is idle (not processing)."""
|
|
99
|
+
return self._engine.is_idle
|
|
100
|
+
|
|
101
|
+
def has_pending_messages(self) -> bool:
|
|
102
|
+
"""Check if there are pending messages in the queue."""
|
|
103
|
+
return self._engine.has_pending_messages()
|
|
104
|
+
|
|
105
|
+
def abort(self) -> None:
|
|
106
|
+
"""Request abort of current operation."""
|
|
107
|
+
self._signal.set()
|
|
108
|
+
|
|
109
|
+
def shutdown(self) -> None:
|
|
110
|
+
"""Shutdown the agent."""
|
|
111
|
+
self._signal.set()
|
|
112
|
+
|
|
113
|
+
def update_context_tokens(self) -> None:
|
|
114
|
+
"""Recalculate context token usage."""
|
|
115
|
+
from tau.session.compaction import estimate_context_tokens
|
|
116
|
+
session_ctx = self._session_manager.build_session_context()
|
|
117
|
+
llm_messages = _to_llm_messages(session_ctx.messages)
|
|
118
|
+
usage = estimate_context_tokens(llm_messages)
|
|
119
|
+
self._context_tokens = usage.tokens
|
|
120
|
+
|
|
121
|
+
def get_context_usage(self) -> ContextUsage | None:
|
|
122
|
+
"""Get current context token usage and limits."""
|
|
123
|
+
self.update_context_tokens()
|
|
124
|
+
percent = (self._context_tokens / self._context_window * 100) if self._context_window else None
|
|
125
|
+
return ContextUsage(
|
|
126
|
+
tokens=self._context_tokens,
|
|
127
|
+
context_window=self._context_window,
|
|
128
|
+
percent=percent,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
def get_system_prompt(self) -> str:
|
|
132
|
+
"""Get the system prompt for the agent."""
|
|
133
|
+
return self._system_prompt
|
|
134
|
+
|
|
135
|
+
async def wait_for_idle(self) -> None:
|
|
136
|
+
"""Wait until the agent becomes idle."""
|
|
137
|
+
await self._engine.wait_for_idle()
|
|
138
|
+
|
|
139
|
+
async def new_session(self) -> None:
|
|
140
|
+
"""Create a new session."""
|
|
141
|
+
if self._runtime is not None:
|
|
142
|
+
await self._runtime.new_session()
|
|
143
|
+
|
|
144
|
+
async def fork(self, entry_id: str) -> None:
|
|
145
|
+
"""Fork a session from a specific entry."""
|
|
146
|
+
if self._runtime is not None:
|
|
147
|
+
await self._runtime.fork_session(entry_id)
|
|
148
|
+
|
|
149
|
+
async def switch_session(self, session_file: Path) -> None:
|
|
150
|
+
"""Switch to a different session."""
|
|
151
|
+
if self._runtime is not None:
|
|
152
|
+
await self._runtime.resume_session(session_file)
|
|
153
|
+
|
|
154
|
+
# -------------------------------------------------------------------------
|
|
155
|
+
# Engine-level tool hooks (pass-through)
|
|
156
|
+
# -------------------------------------------------------------------------
|
|
157
|
+
|
|
158
|
+
async def _before_tool_call(
|
|
159
|
+
self,
|
|
160
|
+
invocation: ToolInvocation,
|
|
161
|
+
signal: asyncio.Event | None,
|
|
162
|
+
) -> ToolInvocation | None:
|
|
163
|
+
return invocation
|
|
164
|
+
|
|
165
|
+
async def _after_tool_call(
|
|
166
|
+
self,
|
|
167
|
+
invocation: ToolInvocation,
|
|
168
|
+
result: ToolResult,
|
|
169
|
+
signal: asyncio.Event | None,
|
|
170
|
+
) -> ToolResult | None:
|
|
171
|
+
return result
|
|
172
|
+
|
|
173
|
+
# -------------------------------------------------------------------------
|
|
174
|
+
# Internal helpers
|
|
175
|
+
# -------------------------------------------------------------------------
|
|
176
|
+
|
|
177
|
+
async def _on_message_end(self, event: MessageEndEvent) -> None:
|
|
178
|
+
"""Persist an incoming message to the session and track token usage."""
|
|
179
|
+
message = event.message
|
|
180
|
+
if message is None:
|
|
181
|
+
return
|
|
182
|
+
match message:
|
|
183
|
+
case AssistantMessage():
|
|
184
|
+
total = message.usage.input_tokens + message.usage.output_tokens
|
|
185
|
+
if total:
|
|
186
|
+
self._context_tokens = total
|
|
187
|
+
self._session_manager.append_message(message)
|
|
188
|
+
case ToolMessage():
|
|
189
|
+
self._session_manager.append_message(message)
|
|
190
|
+
case _:
|
|
191
|
+
pass
|
|
192
|
+
|
|
193
|
+
async def _on_message_rollback(self, event: "MessageRollbackEvent") -> None:
|
|
194
|
+
"""Retract the last ``event.count`` persisted messages from the session.
|
|
195
|
+
|
|
196
|
+
Fired when an interrupted tool turn is dropped: the assistant tool-call
|
|
197
|
+
message and its tool-result message were already written, so remove them
|
|
198
|
+
to keep the session consistent with what the engine replays.
|
|
199
|
+
"""
|
|
200
|
+
for _ in range(event.count):
|
|
201
|
+
if not self._session_manager.remove_last_message():
|
|
202
|
+
break
|
|
203
|
+
|
|
204
|
+
# -------------------------------------------------------------------------
|
|
205
|
+
# Compaction
|
|
206
|
+
# -------------------------------------------------------------------------
|
|
207
|
+
|
|
208
|
+
async def compact(self, custom_instructions: str | None = None) -> bool:
|
|
209
|
+
"""Manually trigger context compaction. Returns True if compaction ran."""
|
|
210
|
+
from tau.session.compaction import prepare_compaction
|
|
211
|
+
from tau.hooks.engine import CompactionEndEvent
|
|
212
|
+
entries = self._session_manager.get_branch()
|
|
213
|
+
preparation = prepare_compaction(entries, self._config.compaction)
|
|
214
|
+
if preparation is None:
|
|
215
|
+
return False
|
|
216
|
+
result, from_extension = await self._run_compaction(preparation, entries, manual=True, custom_instructions=custom_instructions)
|
|
217
|
+
self._session_manager.append_compaction(
|
|
218
|
+
summary=result.summary,
|
|
219
|
+
first_kept_entry_id=result.first_kept_entry_id,
|
|
220
|
+
tokens_before=result.tokens_before,
|
|
221
|
+
)
|
|
222
|
+
self._compaction_failures = 0
|
|
223
|
+
await self.hooks.emit(CompactionEndEvent(
|
|
224
|
+
manual=True,
|
|
225
|
+
tokens_before=result.tokens_before,
|
|
226
|
+
summary_length=len(result.summary),
|
|
227
|
+
from_extension=from_extension,
|
|
228
|
+
))
|
|
229
|
+
return True
|
|
230
|
+
|
|
231
|
+
async def _check_compaction(self) -> None:
|
|
232
|
+
"""Auto-compact if context usage exceeds the threshold. Circuit-breaks after 3 failures."""
|
|
233
|
+
from tau.session.compaction import (
|
|
234
|
+
estimate_context_tokens, should_compact,
|
|
235
|
+
prepare_compaction,
|
|
236
|
+
)
|
|
237
|
+
from tau.hooks.engine import CompactionEndEvent
|
|
238
|
+
|
|
239
|
+
if self._compaction_failures >= 3:
|
|
240
|
+
return
|
|
241
|
+
|
|
242
|
+
settings = self._config.compaction
|
|
243
|
+
if not settings.enabled:
|
|
244
|
+
return
|
|
245
|
+
|
|
246
|
+
entries = self._session_manager.get_branch()
|
|
247
|
+
session_ctx = self._session_manager.build_session_context()
|
|
248
|
+
llm_messages = _to_llm_messages(session_ctx.messages)
|
|
249
|
+
|
|
250
|
+
usage = estimate_context_tokens(llm_messages)
|
|
251
|
+
if not should_compact(usage.tokens, self._context_window, settings):
|
|
252
|
+
return
|
|
253
|
+
|
|
254
|
+
preparation = prepare_compaction(entries, settings)
|
|
255
|
+
if preparation is None:
|
|
256
|
+
return
|
|
257
|
+
|
|
258
|
+
try:
|
|
259
|
+
result, from_extension = await self._run_compaction(preparation, entries, manual=False)
|
|
260
|
+
self._session_manager.append_compaction(
|
|
261
|
+
summary=result.summary,
|
|
262
|
+
first_kept_entry_id=result.first_kept_entry_id,
|
|
263
|
+
tokens_before=result.tokens_before,
|
|
264
|
+
)
|
|
265
|
+
self._compaction_failures = 0
|
|
266
|
+
await self.hooks.emit(CompactionEndEvent(
|
|
267
|
+
manual=False,
|
|
268
|
+
tokens_before=result.tokens_before,
|
|
269
|
+
summary_length=len(result.summary),
|
|
270
|
+
from_extension=from_extension,
|
|
271
|
+
))
|
|
272
|
+
except Exception:
|
|
273
|
+
self._compaction_failures += 1
|
|
274
|
+
|
|
275
|
+
async def _run_compaction(self, preparation: "CompactionPreparation", entries: list, manual: bool, custom_instructions: str | None = None) -> tuple:
|
|
276
|
+
"""Emit before_compaction (allowing interception), then run the default algorithm.
|
|
277
|
+
|
|
278
|
+
Returns (CompactionResult, from_extension: bool).
|
|
279
|
+
Extensions may cancel (raises RuntimeError) or supply a custom CompactionResult.
|
|
280
|
+
Exceptions in before_compaction handlers are swallowed — first non-error result wins,
|
|
281
|
+
consistent with error-fallthrough behaviour.
|
|
282
|
+
"""
|
|
283
|
+
from tau.session.compaction import compact as _compact
|
|
284
|
+
from tau.hooks.types import BeforeCompactionEvent, BeforeCompactionResult, CompactionStartEvent
|
|
285
|
+
|
|
286
|
+
before_results = await self.hooks.emit(BeforeCompactionEvent(
|
|
287
|
+
preparation=preparation,
|
|
288
|
+
entries=entries,
|
|
289
|
+
manual=manual,
|
|
290
|
+
))
|
|
291
|
+
|
|
292
|
+
for res in before_results:
|
|
293
|
+
if not isinstance(res, BeforeCompactionResult):
|
|
294
|
+
continue
|
|
295
|
+
if res.cancel:
|
|
296
|
+
raise RuntimeError("Compaction cancelled by extension")
|
|
297
|
+
if res.compaction is not None:
|
|
298
|
+
return res.compaction, True
|
|
299
|
+
|
|
300
|
+
await self.hooks.emit(CompactionStartEvent(manual=manual))
|
|
301
|
+
result = await _compact(preparation, self._engine.llm, custom_instructions=custom_instructions) # type: ignore[arg-type]
|
|
302
|
+
return result, False
|
|
303
|
+
|
|
304
|
+
# -------------------------------------------------------------------------
|
|
305
|
+
# Core turn entry point
|
|
306
|
+
# -------------------------------------------------------------------------
|
|
307
|
+
|
|
308
|
+
async def invoke(self, text: str, options: PromptOptions | None = None) -> None:
|
|
309
|
+
"""Run one user turn."""
|
|
310
|
+
if self._phase != AgentPhase.IDLE:
|
|
311
|
+
raise RuntimeError(f"Agent is busy (phase={self._phase!r}). Wait for the current operation to finish.")
|
|
312
|
+
|
|
313
|
+
opts = options or PromptOptions()
|
|
314
|
+
|
|
315
|
+
session_ctx = self._session_manager.build_session_context()
|
|
316
|
+
llm_messages = _to_llm_messages(session_ctx.messages)
|
|
317
|
+
llm_messages = strip_unusable_trailing_assistant(llm_messages, self._session_manager)
|
|
318
|
+
|
|
319
|
+
if opts.images:
|
|
320
|
+
user_message = UserMessage.with_images(text, list(opts.images))
|
|
321
|
+
elif opts.audio:
|
|
322
|
+
user_message = UserMessage.with_audio(text, list(opts.audio))
|
|
323
|
+
elif opts.video:
|
|
324
|
+
user_message = UserMessage.with_video(text, list(opts.video))
|
|
325
|
+
else:
|
|
326
|
+
user_message = UserMessage.from_text(text)
|
|
327
|
+
self._session_manager.append_message(user_message, meta=opts.meta)
|
|
328
|
+
llm_messages.append(user_message)
|
|
329
|
+
|
|
330
|
+
ctx = AgentContext(
|
|
331
|
+
system_prompt=self._system_prompt,
|
|
332
|
+
messages=llm_messages,
|
|
333
|
+
tools=self._engine.tools,
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
self._signal = asyncio.Event()
|
|
337
|
+
self._engine.llm.api.options.signal = self._signal
|
|
338
|
+
|
|
339
|
+
self._phase = AgentPhase.TURN
|
|
340
|
+
try:
|
|
341
|
+
await self._run(ctx)
|
|
342
|
+
finally:
|
|
343
|
+
self._phase = AgentPhase.IDLE
|
|
344
|
+
|
|
345
|
+
await self.hooks.emit(SavePointEvent())
|
|
346
|
+
|
|
347
|
+
await self._check_compaction()
|
|
348
|
+
|
|
349
|
+
if not self._engine.has_pending_messages():
|
|
350
|
+
await self.hooks.emit(SettledEvent())
|
|
351
|
+
|
|
352
|
+
async def _run(self, ctx: AgentContext) -> None:
|
|
353
|
+
unsubscribe = self.hooks.register(
|
|
354
|
+
'message_end',
|
|
355
|
+
lambda event: self._on_message_end(event),
|
|
356
|
+
)
|
|
357
|
+
unsubscribe_rollback = self.hooks.register(
|
|
358
|
+
'message_rollback',
|
|
359
|
+
lambda event: self._on_message_rollback(event),
|
|
360
|
+
)
|
|
361
|
+
try:
|
|
362
|
+
await self._engine.run(ctx, signal=self._signal)
|
|
363
|
+
finally:
|
|
364
|
+
unsubscribe()
|
|
365
|
+
unsubscribe_rollback()
|
|
366
|
+
|
|
367
|
+
error = self._engine.state.error_message
|
|
368
|
+
if error is not None:
|
|
369
|
+
raise RuntimeError(f"Agent failed: {error}.")
|
tau/agent/types.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
from tau.message.types import LLMMessage
|
|
10
|
+
from tau.session.types import MessageMeta
|
|
11
|
+
from tau.session.compaction import CompactionSettings, DEFAULT_COMPACTION_SETTINGS
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from tau.tool.types import Tool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AgentPhase(str, Enum):
    """Coarse execution state of an agent.

    Mixes in ``str`` so each member compares equal to its plain string value.
    """

    # Nothing running; a new user turn may begin.
    IDLE = "idle"
    # A user turn is currently being executed.
    TURN = "turn"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class AgentContext:
    """Everything the LLM is given for a single turn.

    Attributes:
        system_prompt: System prompt text for the turn.
        messages: Conversation history, already in LLM-message form.
        tools: Tools the model may call; every instance gets its own list.
    """

    system_prompt: str
    messages: list[LLMMessage]
    tools: list[Tool] = field(default_factory=lambda: [])
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AgentConfig(BaseModel):
    """Runtime settings handed to ``Agent.__init__`` (internal, not user-facing)."""

    model_config = {"arbitrary_types_allowed": True}

    # Working directory the agent operates in.
    cwd: Path
    # System prompt text used for every turn.
    system_prompt: str = ""
    # Untyped model handle; presumably the selected model descriptor (TODO confirm).
    model: Any | None = None
    # Context window size in tokens.
    context_window: int = 200000
    # Settings controlling context compaction.
    compaction: CompactionSettings = DEFAULT_COMPACTION_SETTINGS
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class PromptOptions(BaseModel):
    """Per-prompt options accepted when submitting a user turn."""

    model_config = {"arbitrary_types_allowed": True}

    # Optional metadata stored alongside the persisted user message.
    meta: MessageMeta | None = None
    # Raw image payloads to attach to the prompt.
    images: list[bytes] = []
    # Raw audio payloads to attach to the prompt.
    audio: list[bytes] = []
    # Raw video payloads to attach to the prompt.
    video: list[bytes] = []
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class ContextUsage:
    """How much of the model's context window is in use.

    Attributes:
        tokens: Token count currently occupying the context.
        context_window: Context window size in tokens.
        percent: ``tokens`` as a percentage of ``context_window``, or None
            when no percentage was computed.
    """

    tokens: int
    context_window: int
    percent: float | None = None
|
|
60
|
+
|
|
61
|
+
|
tau/auth/manager.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import List
|
|
6
|
+
from tau.inference.provider.registry import ProviderRegistry
|
|
7
|
+
from tau.inference.provider.oauth import OAuthLoginCallbacks
|
|
8
|
+
from tau.settings.paths import get_auth_path
|
|
9
|
+
from tau.auth.types import AuthCredential, AuthStatus, OAuthCredential, APICredential, AuthType, LockResult
|
|
10
|
+
from tau.auth.storage import AuthStorage, FileAuthStorage, InMemoryAuthStorage
|
|
11
|
+
from tau.utils.secrets import resolve_secret
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_env_api_key(provider: str) -> str | None:
    """Get API key for a provider from environment variables.

    The primary lookup is ``<PROVIDER>_API_KEY`` with the provider id
    upper-cased (e.g. ``openai`` -> ``OPENAI_API_KEY``).

    Provider ids containing characters other than ASCII letters and digits
    (``-``, ``.``, spaces, ...) produce names such as
    ``GITHUB-COPILOT_API_KEY``, which POSIX shells cannot ``export``. When
    that exact variable is unset, the shell-safe form is also tried, with
    each such character replaced by ``_`` (``GITHUB_COPILOT_API_KEY``).

    Args:
        provider: Provider identifier.

    Returns:
        The variable's value (possibly an empty string), or None if neither
        name is set.
    """
    upper = provider.upper()
    exact_name = f"{upper}_API_KEY"
    value = os.environ.get(exact_name)
    if value is not None:
        return value
    # Fallback: portable variable name made of ASCII letters, digits and "_".
    safe_name = "".join(ch if (ch.isascii() and ch.isalnum()) else "_" for ch in upper) + "_API_KEY"
    if safe_name == exact_name:
        return None
    return os.environ.get(safe_name)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AuthManager:
|
|
20
|
+
"""Credential storage with pluggable backends."""
|
|
21
|
+
|
|
22
|
+
def __init__(self, registry: ProviderRegistry, storage: AuthStorage):
|
|
23
|
+
self.registry = registry
|
|
24
|
+
self.storage = storage
|
|
25
|
+
self.runtime_overrides: dict[str, str] = {}
|
|
26
|
+
self._load_error: Exception | None = None
|
|
27
|
+
self._errors: list[Exception] = []
|
|
28
|
+
self.data: dict[str, AuthCredential] = self._load()
|
|
29
|
+
|
|
30
|
+
@staticmethod
|
|
31
|
+
def create(registry: ProviderRegistry, auth_path: Path | None = None) -> "AuthManager":
|
|
32
|
+
"""Create AuthManager with file storage."""
|
|
33
|
+
path = auth_path or get_auth_path()
|
|
34
|
+
storage = FileAuthStorage(path)
|
|
35
|
+
return AuthManager(registry, storage)
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def from_storage(registry: ProviderRegistry, storage: AuthStorage) -> "AuthManager":
|
|
39
|
+
"""Create AuthManager with custom storage."""
|
|
40
|
+
return AuthManager(registry, storage)
|
|
41
|
+
|
|
42
|
+
@staticmethod
|
|
43
|
+
def in_memory(registry: ProviderRegistry, initial: dict = {}) -> "AuthManager":
|
|
44
|
+
"""Create AuthManager with in-memory storage for testing."""
|
|
45
|
+
storage = InMemoryAuthStorage()
|
|
46
|
+
storage.with_lock(lambda _: LockResult(result=None, next=json.dumps(initial, indent=2)))
|
|
47
|
+
return AuthManager.from_storage(registry, storage)
|
|
48
|
+
|
|
49
|
+
def _record_error(self, error: Exception) -> None:
|
|
50
|
+
"""Record an error for later retrieval."""
|
|
51
|
+
self._errors.append(error)
|
|
52
|
+
|
|
53
|
+
def _parse_storage_data(self, content: str | None) -> dict[str, AuthCredential]:
|
|
54
|
+
"""Parse credential data from storage JSON."""
|
|
55
|
+
if not content:
|
|
56
|
+
return {}
|
|
57
|
+
raw_data = json.loads(content)
|
|
58
|
+
data: dict[str, AuthCredential] = {}
|
|
59
|
+
for k, v in raw_data.items():
|
|
60
|
+
cred_type = v.get("type")
|
|
61
|
+
match cred_type:
|
|
62
|
+
case AuthType.OAuth:
|
|
63
|
+
raw_extra = v.get("extra") or {}
|
|
64
|
+
extra = {str(ek): str(ev) for ek, ev in raw_extra.items()} if isinstance(raw_extra, dict) else {}
|
|
65
|
+
data[k] = OAuthCredential(
|
|
66
|
+
access=v.get("access", ""),
|
|
67
|
+
refresh=v.get("refresh", ""),
|
|
68
|
+
expires=v.get("expires", 0),
|
|
69
|
+
extra=extra,
|
|
70
|
+
)
|
|
71
|
+
case AuthType.ApiKey:
|
|
72
|
+
data[k] = APICredential(key=v.get("key", ""))
|
|
73
|
+
return data
|
|
74
|
+
|
|
75
|
+
def _load(self) -> dict[str, AuthCredential]:
|
|
76
|
+
"""Load credentials from storage."""
|
|
77
|
+
try:
|
|
78
|
+
result = self.storage.with_lock(lambda current: LockResult(result=current))
|
|
79
|
+
self._load_error = None
|
|
80
|
+
return self._parse_storage_data(result.result)
|
|
81
|
+
except Exception as e:
|
|
82
|
+
self._load_error = e
|
|
83
|
+
self._record_error(e)
|
|
84
|
+
return {}
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _serialize_credential(credential: AuthCredential) -> dict:
|
|
88
|
+
"""Serialize a credential to storable dict format."""
|
|
89
|
+
if isinstance(credential, OAuthCredential):
|
|
90
|
+
return {
|
|
91
|
+
"type": AuthType.OAuth,
|
|
92
|
+
"access": credential.access,
|
|
93
|
+
"refresh": credential.refresh,
|
|
94
|
+
"expires": credential.expires,
|
|
95
|
+
"extra": dict(credential.extra),
|
|
96
|
+
}
|
|
97
|
+
return {"type": AuthType.ApiKey, "key": credential.key}
|
|
98
|
+
|
|
99
|
+
def _persist_provider_change(self, provider: str, credential: AuthCredential | None) -> None:
|
|
100
|
+
"""Persist a credential change to storage."""
|
|
101
|
+
if self._load_error:
|
|
102
|
+
return
|
|
103
|
+
|
|
104
|
+
def update_fn(current: str | None) -> LockResult:
|
|
105
|
+
"""Update storage data with new credential."""
|
|
106
|
+
current_data = self._parse_storage_data(current)
|
|
107
|
+
merged = {k: self._serialize_credential(v) for k, v in current_data.items()}
|
|
108
|
+
if credential:
|
|
109
|
+
merged[provider] = self._serialize_credential(credential)
|
|
110
|
+
else:
|
|
111
|
+
merged.pop(provider, None)
|
|
112
|
+
return LockResult(result=None, next=json.dumps(merged, indent=2))
|
|
113
|
+
|
|
114
|
+
try:
|
|
115
|
+
self.storage.with_lock(update_fn)
|
|
116
|
+
except Exception as e:
|
|
117
|
+
self._record_error(e)
|
|
118
|
+
|
|
119
|
+
def reload(self) -> None:
|
|
120
|
+
"""Reload credentials from storage."""
|
|
121
|
+
self.data = self._load()
|
|
122
|
+
|
|
123
|
+
def get(self, provider: str) -> AuthCredential | None:
|
|
124
|
+
"""Return the stored credential for a provider, or None if not found."""
|
|
125
|
+
return self.data.get(provider)
|
|
126
|
+
|
|
127
|
+
def has(self, provider: str) -> bool:
|
|
128
|
+
"""Check if credentials exist for a provider in storage."""
|
|
129
|
+
return provider in self.data
|
|
130
|
+
|
|
131
|
+
def list(self) -> list[str]:
|
|
132
|
+
"""List all providers with stored credentials."""
|
|
133
|
+
return list(self.data.keys())
|
|
134
|
+
|
|
135
|
+
def set(self, provider: str, credential: AuthCredential) -> None:
|
|
136
|
+
"""Store a credential for a provider and persist to storage."""
|
|
137
|
+
self.data[provider] = credential
|
|
138
|
+
self._persist_provider_change(provider=provider, credential=credential)
|
|
139
|
+
|
|
140
|
+
def remove(self, provider: str) -> None:
|
|
141
|
+
"""Remove the stored credential for a provider and persist to storage."""
|
|
142
|
+
self.data.pop(provider, None)
|
|
143
|
+
self._persist_provider_change(provider=provider, credential=None)
|
|
144
|
+
|
|
145
|
+
def set_runtime_api_key(self, provider: str, api_key: str) -> None:
|
|
146
|
+
"""Set a runtime API key override (not persisted)."""
|
|
147
|
+
self.runtime_overrides[provider] = api_key
|
|
148
|
+
|
|
149
|
+
def remove_runtime_api_key(self, provider: str) -> None:
|
|
150
|
+
"""Remove a runtime API key override."""
|
|
151
|
+
self.runtime_overrides.pop(provider, None)
|
|
152
|
+
|
|
153
|
+
def get_auth_status(self, provider: str) -> AuthStatus:
|
|
154
|
+
"""Return auth status without exposing credential values."""
|
|
155
|
+
if self.has(provider):
|
|
156
|
+
return AuthStatus(configured=True, source="stored")
|
|
157
|
+
if provider in self.runtime_overrides:
|
|
158
|
+
return AuthStatus(configured=True, source="runtime", label="--api-key")
|
|
159
|
+
env_key = f"{provider.upper()}_API_KEY"
|
|
160
|
+
if os.environ.get(env_key):
|
|
161
|
+
return AuthStatus(configured=True, source="env", label=env_key)
|
|
162
|
+
return AuthStatus(configured=False)
|
|
163
|
+
|
|
164
|
+
def drain_errors(self) -> List[Exception]:
|
|
165
|
+
"""Return and clear accumulated errors."""
|
|
166
|
+
drained = list(self._errors)
|
|
167
|
+
self._errors.clear()
|
|
168
|
+
return drained
|
|
169
|
+
|
|
170
|
+
async def get_api_key(self, provider: str) -> str | None:
|
|
171
|
+
"""Get an API key for a provider, refreshing OAuth tokens if needed."""
|
|
172
|
+
# 1. Runtime override
|
|
173
|
+
if provider in self.runtime_overrides:
|
|
174
|
+
return resolve_secret(self.runtime_overrides[provider])
|
|
175
|
+
|
|
176
|
+
credential = self.get(provider)
|
|
177
|
+
|
|
178
|
+
match credential:
|
|
179
|
+
case APICredential():
|
|
180
|
+
# The stored key may be a literal, "$ENV_VAR", or "!command";
|
|
181
|
+
# resolved once and cached (see tau.utils.secrets).
|
|
182
|
+
return resolve_secret(credential.key)
|
|
183
|
+
case OAuthCredential():
|
|
184
|
+
oauth_provider = self.registry.text.get_oauth_provider(provider=provider)
|
|
185
|
+
if not oauth_provider:
|
|
186
|
+
return None
|
|
187
|
+
|
|
188
|
+
if oauth_provider.is_expired(credential=credential):
|
|
189
|
+
refreshed_credential = await self._refresh_oauth_token_with_lock(provider=provider)
|
|
190
|
+
if refreshed_credential:
|
|
191
|
+
credential = refreshed_credential
|
|
192
|
+
else:
|
|
193
|
+
return None
|
|
194
|
+
return oauth_provider.get_api_key(credential=credential)
|
|
195
|
+
|
|
196
|
+
# 2. Environment variable fallback
|
|
197
|
+
return _get_env_api_key(provider)
|
|
198
|
+
|
|
199
|
+
async def _refresh_oauth_token_with_lock(self, provider: str) -> OAuthCredential | None:
|
|
200
|
+
"""Refresh an expired OAuth token with file locking to prevent race conditions."""
|
|
201
|
+
oauth_provider = self.registry.text.get_oauth_provider(provider=provider)
|
|
202
|
+
if not oauth_provider:
|
|
203
|
+
return None
|
|
204
|
+
|
|
205
|
+
async def refresh_fn(current: str | None) -> LockResult:
|
|
206
|
+
"""Refresh OAuth token in storage."""
|
|
207
|
+
current_data = self._parse_storage_data(current)
|
|
208
|
+
credential = current_data.get(provider)
|
|
209
|
+
|
|
210
|
+
if not isinstance(credential, OAuthCredential):
|
|
211
|
+
return LockResult(result=None)
|
|
212
|
+
|
|
213
|
+
# Check if another instance already refreshed
|
|
214
|
+
if not oauth_provider.is_expired(credential=credential):
|
|
215
|
+
return LockResult(result=credential)
|
|
216
|
+
|
|
217
|
+
try:
|
|
218
|
+
refreshed_credential = await oauth_provider.refresh_token(credential=credential)
|
|
219
|
+
if credential.extra:
|
|
220
|
+
merged_extra = dict(credential.extra)
|
|
221
|
+
merged_extra.update(refreshed_credential.extra)
|
|
222
|
+
refreshed_credential.extra = merged_extra
|
|
223
|
+
current_data[provider] = refreshed_credential
|
|
224
|
+
self.data = current_data
|
|
225
|
+
serialized = {k: self._serialize_credential(v) for k, v in current_data.items()}
|
|
226
|
+
return LockResult(result=refreshed_credential, next=json.dumps(serialized, indent=2))
|
|
227
|
+
except Exception as e:
|
|
228
|
+
self._record_error(e)
|
|
229
|
+
return LockResult(result=None)
|
|
230
|
+
|
|
231
|
+
result = await self.storage.with_lock_async(refresh_fn)
|
|
232
|
+
return result.result
|
|
233
|
+
|
|
234
|
+
async def login(self, provider: str, callbacks: OAuthLoginCallbacks):
|
|
235
|
+
"""Perform OAuth login flow for a provider and store the resulting credential."""
|
|
236
|
+
if oauth_provider := self.registry.text.get_oauth_provider(provider):
|
|
237
|
+
credential = await oauth_provider.login(callbacks=callbacks)
|
|
238
|
+
self.data[provider] = credential
|
|
239
|
+
self._persist_provider_change(provider, credential)
|
|
240
|
+
|
|
241
|
+
async def logout(self, provider: str):
|
|
242
|
+
"""Perform OAuth logout for a provider and remove the stored credential."""
|
|
243
|
+
if oauth_provider := self.registry.text.get_oauth_provider(provider):
|
|
244
|
+
if credential := self.get(provider):
|
|
245
|
+
if isinstance(credential, OAuthCredential):
|
|
246
|
+
await oauth_provider.logout(credential=credential)
|
|
247
|
+
self.remove(provider)
|