tau-coding-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (283)
  1. tau/__init__.py +0 -0
  2. tau/agent/__init__.py +11 -0
  3. tau/agent/prompt/__init__.py +10 -0
  4. tau/agent/prompt/builder.py +302 -0
  5. tau/agent/prompt/types.py +33 -0
  6. tau/agent/service.py +369 -0
  7. tau/agent/types.py +61 -0
  8. tau/auth/manager.py +247 -0
  9. tau/auth/storage.py +82 -0
  10. tau/auth/types.py +41 -0
  11. tau/builtins/__init__.py +4 -0
  12. tau/builtins/__pycache__/__init__.cpython-313.pyc +0 -0
  13. tau/builtins/__pycache__/__init__.cpython-314.pyc +0 -0
  14. tau/builtins/commands/__init__.py +41 -0
  15. tau/builtins/commands/__pycache__/__init__.cpython-313.pyc +0 -0
  16. tau/builtins/commands/__pycache__/__init__.cpython-314.pyc +0 -0
  17. tau/builtins/commands/__pycache__/clear.cpython-313.pyc +0 -0
  18. tau/builtins/commands/__pycache__/clear.cpython-314.pyc +0 -0
  19. tau/builtins/commands/__pycache__/compact.cpython-313.pyc +0 -0
  20. tau/builtins/commands/__pycache__/compact.cpython-314.pyc +0 -0
  21. tau/builtins/commands/__pycache__/reload.cpython-313.pyc +0 -0
  22. tau/builtins/commands/__pycache__/reload.cpython-314.pyc +0 -0
  23. tau/builtins/commands/__pycache__/session.cpython-313.pyc +0 -0
  24. tau/builtins/commands/__pycache__/session.cpython-314.pyc +0 -0
  25. tau/builtins/commands/clear.py +16 -0
  26. tau/builtins/commands/compact.py +28 -0
  27. tau/builtins/commands/reload.py +27 -0
  28. tau/builtins/commands/session.py +19 -0
  29. tau/builtins/extensions/footer/__init__.py +76 -0
  30. tau/builtins/extensions/footer/__pycache__/__init__.cpython-313.pyc +0 -0
  31. tau/builtins/extensions/footer/__pycache__/git.cpython-313.pyc +0 -0
  32. tau/builtins/extensions/footer/__pycache__/model.cpython-313.pyc +0 -0
  33. tau/builtins/extensions/footer/__pycache__/utils.cpython-313.pyc +0 -0
  34. tau/builtins/extensions/footer/git.py +26 -0
  35. tau/builtins/extensions/footer/model.py +69 -0
  36. tau/builtins/extensions/footer/utils.py +44 -0
  37. tau/builtins/extensions/header/__init__.py +18 -0
  38. tau/builtins/extensions/header/__pycache__/__init__.cpython-313.pyc +0 -0
  39. tau/builtins/models/__init__.py +0 -0
  40. tau/builtins/models/__pycache__/__init__.cpython-313.pyc +0 -0
  41. tau/builtins/models/__pycache__/text.cpython-313.pyc +0 -0
  42. tau/builtins/models/audio.py +43 -0
  43. tau/builtins/models/image.py +43 -0
  44. tau/builtins/models/text.py +482 -0
  45. tau/builtins/models/video.py +40 -0
  46. tau/builtins/prompts/commit.md +7 -0
  47. tau/builtins/prompts/docs.md +7 -0
  48. tau/builtins/prompts/explain.md +7 -0
  49. tau/builtins/prompts/fix.md +7 -0
  50. tau/builtins/prompts/refactor.md +7 -0
  51. tau/builtins/prompts/review.md +7 -0
  52. tau/builtins/prompts/test.md +7 -0
  53. tau/builtins/providers/__init__.py +0 -0
  54. tau/builtins/providers/__pycache__/__init__.cpython-313.pyc +0 -0
  55. tau/builtins/providers/__pycache__/text.cpython-313.pyc +0 -0
  56. tau/builtins/providers/audio.py +10 -0
  57. tau/builtins/providers/image.py +9 -0
  58. tau/builtins/providers/text.py +33 -0
  59. tau/builtins/providers/video.py +6 -0
  60. tau/builtins/skills/code-review/SKILL.md +4 -0
  61. tau/builtins/skills/debug/SKILL.md +4 -0
  62. tau/builtins/skills/git-commit/SKILL.md +4 -0
  63. tau/builtins/themes/dark.yaml +1 -0
  64. tau/builtins/themes/light.yaml +46 -0
  65. tau/builtins/tools/__init__.py +73 -0
  66. tau/builtins/tools/__pycache__/__init__.cpython-313.pyc +0 -0
  67. tau/builtins/tools/__pycache__/__init__.cpython-314.pyc +0 -0
  68. tau/builtins/tools/__pycache__/bash.cpython-313.pyc +0 -0
  69. tau/builtins/tools/__pycache__/bash.cpython-314.pyc +0 -0
  70. tau/builtins/tools/__pycache__/edit.cpython-313.pyc +0 -0
  71. tau/builtins/tools/__pycache__/edit.cpython-314.pyc +0 -0
  72. tau/builtins/tools/__pycache__/glob.cpython-313.pyc +0 -0
  73. tau/builtins/tools/__pycache__/glob.cpython-314.pyc +0 -0
  74. tau/builtins/tools/__pycache__/grep.cpython-313.pyc +0 -0
  75. tau/builtins/tools/__pycache__/grep.cpython-314.pyc +0 -0
  76. tau/builtins/tools/__pycache__/ls.cpython-313.pyc +0 -0
  77. tau/builtins/tools/__pycache__/ls.cpython-314.pyc +0 -0
  78. tau/builtins/tools/__pycache__/read.cpython-313.pyc +0 -0
  79. tau/builtins/tools/__pycache__/read.cpython-314.pyc +0 -0
  80. tau/builtins/tools/__pycache__/terminal.cpython-313.pyc +0 -0
  81. tau/builtins/tools/__pycache__/terminal.cpython-314.pyc +0 -0
  82. tau/builtins/tools/__pycache__/write.cpython-313.pyc +0 -0
  83. tau/builtins/tools/__pycache__/write.cpython-314.pyc +0 -0
  84. tau/builtins/tools/edit.py +215 -0
  85. tau/builtins/tools/glob.py +112 -0
  86. tau/builtins/tools/grep.py +146 -0
  87. tau/builtins/tools/ls.py +135 -0
  88. tau/builtins/tools/read.py +122 -0
  89. tau/builtins/tools/terminal.py +150 -0
  90. tau/builtins/tools/write.py +105 -0
  91. tau/commands/__init__.py +10 -0
  92. tau/commands/registry.py +71 -0
  93. tau/commands/types.py +33 -0
  94. tau/console/__init__.py +0 -0
  95. tau/console/cli.py +266 -0
  96. tau/console/commands/__init__.py +0 -0
  97. tau/console/commands/auth.py +193 -0
  98. tau/console/commands/packages.py +104 -0
  99. tau/console/commands/update.py +76 -0
  100. tau/core/__init__.py +0 -0
  101. tau/core/registry.py +102 -0
  102. tau/engine/__init__.py +47 -0
  103. tau/engine/service.py +768 -0
  104. tau/engine/types.py +163 -0
  105. tau/extensions/__init__.py +28 -0
  106. tau/extensions/api.py +928 -0
  107. tau/extensions/context.py +462 -0
  108. tau/extensions/events.py +70 -0
  109. tau/extensions/loader.py +386 -0
  110. tau/extensions/runtime.py +184 -0
  111. tau/extensions/settings.py +137 -0
  112. tau/hooks/__init__.py +112 -0
  113. tau/hooks/engine.py +237 -0
  114. tau/hooks/inference.py +21 -0
  115. tau/hooks/runtime.py +126 -0
  116. tau/hooks/service.py +121 -0
  117. tau/hooks/session.py +117 -0
  118. tau/hooks/tui.py +61 -0
  119. tau/hooks/types.py +72 -0
  120. tau/inference/__init__.py +180 -0
  121. tau/inference/api/__init__.py +0 -0
  122. tau/inference/api/audio/__init__.py +0 -0
  123. tau/inference/api/audio/base.py +29 -0
  124. tau/inference/api/audio/builtins.py +15 -0
  125. tau/inference/api/audio/elevenlabs_audio.py +183 -0
  126. tau/inference/api/audio/gemini_audio.py +95 -0
  127. tau/inference/api/audio/openai_audio.py +159 -0
  128. tau/inference/api/audio/registry.py +15 -0
  129. tau/inference/api/audio/sarvam_audio.py +163 -0
  130. tau/inference/api/audio/service.py +103 -0
  131. tau/inference/api/audio/utils.py +47 -0
  132. tau/inference/api/image/__init__.py +0 -0
  133. tau/inference/api/image/base.py +17 -0
  134. tau/inference/api/image/builtins.py +8 -0
  135. tau/inference/api/image/gemini_image.py +77 -0
  136. tau/inference/api/image/openai_image.py +103 -0
  137. tau/inference/api/image/openrouter.py +144 -0
  138. tau/inference/api/image/registry.py +15 -0
  139. tau/inference/api/image/service.py +71 -0
  140. tau/inference/api/registry.py +82 -0
  141. tau/inference/api/text/__init__.py +0 -0
  142. tau/inference/api/text/anthropic_claude_code.py +222 -0
  143. tau/inference/api/text/anthropic_messages.py +196 -0
  144. tau/inference/api/text/base.py +40 -0
  145. tau/inference/api/text/builtins.py +19 -0
  146. tau/inference/api/text/gemini_generate.py +234 -0
  147. tau/inference/api/text/github_copilot_chat.py +172 -0
  148. tau/inference/api/text/google_antigravity.py +522 -0
  149. tau/inference/api/text/mistral_chat.py +284 -0
  150. tau/inference/api/text/ollama_chat.py +200 -0
  151. tau/inference/api/text/openai_codex_responses.py +497 -0
  152. tau/inference/api/text/openai_completions.py +227 -0
  153. tau/inference/api/text/openai_responses.py +235 -0
  154. tau/inference/api/text/registry.py +50 -0
  155. tau/inference/api/text/service.py +297 -0
  156. tau/inference/api/text/types.py +7 -0
  157. tau/inference/api/text/utils.py +228 -0
  158. tau/inference/api/video/__init__.py +0 -0
  159. tau/inference/api/video/base.py +26 -0
  160. tau/inference/api/video/builtins.py +7 -0
  161. tau/inference/api/video/fal_video.py +119 -0
  162. tau/inference/api/video/openrouter_video.py +142 -0
  163. tau/inference/api/video/registry.py +15 -0
  164. tau/inference/api/video/service.py +72 -0
  165. tau/inference/model/__init__.py +0 -0
  166. tau/inference/model/registry.py +102 -0
  167. tau/inference/model/types.py +65 -0
  168. tau/inference/provider/__init__.py +0 -0
  169. tau/inference/provider/oauth/__init__.py +35 -0
  170. tau/inference/provider/oauth/anthropic_claude_code.py +286 -0
  171. tau/inference/provider/oauth/github_copilot.py +333 -0
  172. tau/inference/provider/oauth/google_antigravity.py +258 -0
  173. tau/inference/provider/oauth/openai_codex.py +309 -0
  174. tau/inference/provider/oauth/pkce.py +14 -0
  175. tau/inference/provider/oauth/types.py +46 -0
  176. tau/inference/provider/oauth/utils.py +154 -0
  177. tau/inference/provider/registry.py +141 -0
  178. tau/inference/provider/types.py +114 -0
  179. tau/inference/types.py +549 -0
  180. tau/inference/utils.py +219 -0
  181. tau/message/__init__.py +0 -0
  182. tau/message/types.py +482 -0
  183. tau/message/utils.py +178 -0
  184. tau/packages/__init__.py +11 -0
  185. tau/packages/manager.py +190 -0
  186. tau/packages/types.py +20 -0
  187. tau/packages/utils.py +67 -0
  188. tau/prompts/expand.py +58 -0
  189. tau/prompts/loader.py +69 -0
  190. tau/prompts/registry.py +45 -0
  191. tau/prompts/types.py +24 -0
  192. tau/rpc/__init__.py +8 -0
  193. tau/rpc/mode.py +783 -0
  194. tau/rpc/types.py +252 -0
  195. tau/runtime/service.py +759 -0
  196. tau/runtime/types.py +303 -0
  197. tau/session/branch_summarization.py +312 -0
  198. tau/session/compaction.py +646 -0
  199. tau/session/manager.py +652 -0
  200. tau/session/types.py +188 -0
  201. tau/session/utils.py +233 -0
  202. tau/settings/manager.py +1077 -0
  203. tau/settings/paths.py +150 -0
  204. tau/settings/storage.py +63 -0
  205. tau/settings/types.py +173 -0
  206. tau/settings/utils.py +25 -0
  207. tau/skills/loader.py +91 -0
  208. tau/skills/registry.py +70 -0
  209. tau/skills/types.py +25 -0
  210. tau/themes/loader.py +238 -0
  211. tau/themes/registry.py +108 -0
  212. tau/themes/types.py +19 -0
  213. tau/tool/__init__.py +3 -0
  214. tau/tool/registry.py +117 -0
  215. tau/tool/render.py +21 -0
  216. tau/tool/types.py +244 -0
  217. tau/trust/__init__.py +13 -0
  218. tau/trust/manager.py +80 -0
  219. tau/trust/types.py +14 -0
  220. tau/trust/utils.py +72 -0
  221. tau/tui/__init__.py +54 -0
  222. tau/tui/agent_hooks.py +346 -0
  223. tau/tui/ansi.py +330 -0
  224. tau/tui/app.py +540 -0
  225. tau/tui/autocomplete.py +33 -0
  226. tau/tui/capabilities.py +119 -0
  227. tau/tui/commands/__init__.py +3 -0
  228. tau/tui/commands/appearance.py +498 -0
  229. tau/tui/commands/auth.py +232 -0
  230. tau/tui/commands/context.py +38 -0
  231. tau/tui/commands/misc.py +82 -0
  232. tau/tui/commands/model.py +118 -0
  233. tau/tui/commands/session.py +464 -0
  234. tau/tui/component.py +268 -0
  235. tau/tui/components/__init__.py +0 -0
  236. tau/tui/components/autocomplete_manager.py +267 -0
  237. tau/tui/components/autocomplete_picker.py +143 -0
  238. tau/tui/components/box.py +90 -0
  239. tau/tui/components/command_palette.py +144 -0
  240. tau/tui/components/dynamic_border.py +19 -0
  241. tau/tui/components/file_picker.py +233 -0
  242. tau/tui/components/image.py +181 -0
  243. tau/tui/components/inline_selector.py +71 -0
  244. tau/tui/components/layout.py +1194 -0
  245. tau/tui/components/message_list.py +692 -0
  246. tau/tui/components/modal.py +97 -0
  247. tau/tui/components/model_palette.py +204 -0
  248. tau/tui/components/picker_overlay.py +174 -0
  249. tau/tui/components/prompt_overlay.py +236 -0
  250. tau/tui/components/resume_modal.py +372 -0
  251. tau/tui/components/select_list.py +222 -0
  252. tau/tui/components/settings_modal.py +274 -0
  253. tau/tui/components/settings_schema.py +203 -0
  254. tau/tui/components/spinner.py +119 -0
  255. tau/tui/components/text_input.py +396 -0
  256. tau/tui/components/text_prompt.py +82 -0
  257. tau/tui/components/tree_select_list.py +580 -0
  258. tau/tui/components/trust_screen.py +97 -0
  259. tau/tui/diff.py +114 -0
  260. tau/tui/fuzzy.py +99 -0
  261. tau/tui/input.py +496 -0
  262. tau/tui/input_handler.py +716 -0
  263. tau/tui/keybindings.py +87 -0
  264. tau/tui/markdown.py +286 -0
  265. tau/tui/message_renderers.py +31 -0
  266. tau/tui/overlay.py +326 -0
  267. tau/tui/renderer.py +378 -0
  268. tau/tui/terminal.py +499 -0
  269. tau/tui/theme.py +148 -0
  270. tau/tui/tui.py +544 -0
  271. tau/tui/ui_context.py +768 -0
  272. tau/tui/utils.py +20 -0
  273. tau/utils/__init__.py +0 -0
  274. tau/utils/http_proxy.py +221 -0
  275. tau/utils/image_processing.py +172 -0
  276. tau/utils/secrets.py +59 -0
  277. tau/utils/version_check.py +60 -0
  278. tau_coding_agent-0.1.0.dist-info/METADATA +177 -0
  279. tau_coding_agent-0.1.0.dist-info/RECORD +283 -0
  280. tau_coding_agent-0.1.0.dist-info/WHEEL +5 -0
  281. tau_coding_agent-0.1.0.dist-info/entry_points.txt +2 -0
  282. tau_coding_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
  283. tau_coding_agent-0.1.0.dist-info/top_level.txt +1 -0
tau/agent/service.py ADDED
@@ -0,0 +1,369 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ from tau.agent.types import AgentConfig, AgentContext, AgentPhase, PromptOptions, ContextUsage
8
+ from tau.hooks.service import Hooks
9
+ from tau.hooks.engine import MessageEndEvent, MessageRollbackEvent, SavePointEvent, SettledEvent
10
+ from tau.message.types import AgentMessage, AssistantMessage, TerminalExecutionMessage, LLMMessage, UserMessage, TextContent, ToolMessage
11
+ from tau.message.utils import strip_unusable_trailing_assistant
12
+ from tau.tool.types import ToolInvocation, ToolResult
13
+
14
+ if TYPE_CHECKING:
15
+ from tau.engine.service import Engine
16
+ from tau.session.manager import SessionManager
17
+ from tau.runtime.service import Runtime
18
+ from tau.session.compaction import CompactionPreparation
19
+
20
+
21
+
22
+ def _to_llm_messages(messages: list[AgentMessage]) -> list[LLMMessage]:
23
+ """Convert AgentMessages to LLM-compatible messages.
24
+
25
+ TerminalExecutionMessage → UserMessage (Ran `cmd`\n```output```)
26
+ CompactionSummaryMessage → UserMessage with summary wrapped in XML tags
27
+ CustomMessage and other non-LLM types → skipped
28
+ Empty AssistantMessages are visual-only markers (aborts, persisted API/credit
29
+ errors) and are skipped — an assistant turn with neither content nor tool
30
+ calls is invalid to send back and triggers provider 400s.
31
+ """
32
+ from tau.message.types import CompactionSummaryMessage, ToolCallContent, ThinkingContent
33
+ result: list[LLMMessage] = []
34
+ for msg in messages:
35
+ if isinstance(msg, CompactionSummaryMessage):
36
+ text = f"<context-summary>\n{msg.summary}\n</context-summary>"
37
+ result.append(UserMessage.from_text(text))
38
+ elif isinstance(msg, TerminalExecutionMessage):
39
+ if not msg.exclude:
40
+ result.append(msg.to_user_message())
41
+ elif isinstance(msg, AssistantMessage):
42
+ has_usable = any(
43
+ isinstance(c, (TextContent, ToolCallContent, ThinkingContent))
44
+ for c in msg.contents
45
+ )
46
+ if has_usable:
47
+ result.append(msg)
48
+ elif isinstance(msg, (UserMessage, ToolMessage)):
49
+ result.append(msg)
50
+ return result
51
+
52
+
53
+ class Agent:
54
+ """
55
+ High-level agent session tying together Engine and SessionManager.
56
+
57
+ Call `invoke()` to run a user turn. The session persists each message
58
+ and tracks token usage.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ engine: Engine,
64
+ session_manager: SessionManager,
65
+ config: AgentConfig,
66
+ hooks: Hooks | None = None,
67
+ ) -> None:
68
+ self._engine = engine
69
+ self._session_manager = session_manager
70
+ self._config = config
71
+ self._system_prompt: str = config.system_prompt
72
+ self._context_tokens: int = 0
73
+ self._context_window: int = config.context_window
74
+ self._runtime: Runtime | None = None
75
+ self.hooks = hooks or Hooks()
76
+
77
+ self._phase: AgentPhase = AgentPhase.IDLE
78
+ self._signal: asyncio.Event = asyncio.Event()
79
+ self._compaction_failures: int = 0
80
+ self._engine.options.before_tool_call = self._before_tool_call
81
+ self._engine.options.after_tool_call = self._after_tool_call
82
+
83
+ # -------------------------------------------------------------------------
84
+ # Public interface
85
+ # -------------------------------------------------------------------------
86
+
87
+ @property
88
+ def cwd(self) -> Path:
89
+ """Get the current working directory."""
90
+ return self._config.cwd
91
+
92
+ @property
93
+ def session_manager(self) -> SessionManager:
94
+ """Get the session manager instance."""
95
+ return self._session_manager
96
+
97
+ def is_idle(self) -> bool:
98
+ """Check if the agent is idle (not processing)."""
99
+ return self._engine.is_idle
100
+
101
+ def has_pending_messages(self) -> bool:
102
+ """Check if there are pending messages in the queue."""
103
+ return self._engine.has_pending_messages()
104
+
105
+ def abort(self) -> None:
106
+ """Request abort of current operation."""
107
+ self._signal.set()
108
+
109
+ def shutdown(self) -> None:
110
+ """Shutdown the agent."""
111
+ self._signal.set()
112
+
113
+ def update_context_tokens(self) -> None:
114
+ """Recalculate context token usage."""
115
+ from tau.session.compaction import estimate_context_tokens
116
+ session_ctx = self._session_manager.build_session_context()
117
+ llm_messages = _to_llm_messages(session_ctx.messages)
118
+ usage = estimate_context_tokens(llm_messages)
119
+ self._context_tokens = usage.tokens
120
+
121
+ def get_context_usage(self) -> ContextUsage | None:
122
+ """Get current context token usage and limits."""
123
+ self.update_context_tokens()
124
+ percent = (self._context_tokens / self._context_window * 100) if self._context_window else None
125
+ return ContextUsage(
126
+ tokens=self._context_tokens,
127
+ context_window=self._context_window,
128
+ percent=percent,
129
+ )
130
+
131
+ def get_system_prompt(self) -> str:
132
+ """Get the system prompt for the agent."""
133
+ return self._system_prompt
134
+
135
+ async def wait_for_idle(self) -> None:
136
+ """Wait until the agent becomes idle."""
137
+ await self._engine.wait_for_idle()
138
+
139
+ async def new_session(self) -> None:
140
+ """Create a new session."""
141
+ if self._runtime is not None:
142
+ await self._runtime.new_session()
143
+
144
+ async def fork(self, entry_id: str) -> None:
145
+ """Fork a session from a specific entry."""
146
+ if self._runtime is not None:
147
+ await self._runtime.fork_session(entry_id)
148
+
149
+ async def switch_session(self, session_file: Path) -> None:
150
+ """Switch to a different session."""
151
+ if self._runtime is not None:
152
+ await self._runtime.resume_session(session_file)
153
+
154
+ # -------------------------------------------------------------------------
155
+ # Engine-level tool hooks (pass-through)
156
+ # -------------------------------------------------------------------------
157
+
158
+ async def _before_tool_call(
159
+ self,
160
+ invocation: ToolInvocation,
161
+ signal: asyncio.Event | None,
162
+ ) -> ToolInvocation | None:
163
+ return invocation
164
+
165
+ async def _after_tool_call(
166
+ self,
167
+ invocation: ToolInvocation,
168
+ result: ToolResult,
169
+ signal: asyncio.Event | None,
170
+ ) -> ToolResult | None:
171
+ return result
172
+
173
+ # -------------------------------------------------------------------------
174
+ # Internal helpers
175
+ # -------------------------------------------------------------------------
176
+
177
+ async def _on_message_end(self, event: MessageEndEvent) -> None:
178
+ """Persist an incoming message to the session and track token usage."""
179
+ message = event.message
180
+ if message is None:
181
+ return
182
+ match message:
183
+ case AssistantMessage():
184
+ total = message.usage.input_tokens + message.usage.output_tokens
185
+ if total:
186
+ self._context_tokens = total
187
+ self._session_manager.append_message(message)
188
+ case ToolMessage():
189
+ self._session_manager.append_message(message)
190
+ case _:
191
+ pass
192
+
193
+ async def _on_message_rollback(self, event: "MessageRollbackEvent") -> None:
194
+ """Retract the last ``event.count`` persisted messages from the session.
195
+
196
+ Fired when an interrupted tool turn is dropped: the assistant tool-call
197
+ message and its tool-result message were already written, so remove them
198
+ to keep the session consistent with what the engine replays.
199
+ """
200
+ for _ in range(event.count):
201
+ if not self._session_manager.remove_last_message():
202
+ break
203
+
204
+ # -------------------------------------------------------------------------
205
+ # Compaction
206
+ # -------------------------------------------------------------------------
207
+
208
+ async def compact(self, custom_instructions: str | None = None) -> bool:
209
+ """Manually trigger context compaction. Returns True if compaction ran."""
210
+ from tau.session.compaction import prepare_compaction
211
+ from tau.hooks.engine import CompactionEndEvent
212
+ entries = self._session_manager.get_branch()
213
+ preparation = prepare_compaction(entries, self._config.compaction)
214
+ if preparation is None:
215
+ return False
216
+ result, from_extension = await self._run_compaction(preparation, entries, manual=True, custom_instructions=custom_instructions)
217
+ self._session_manager.append_compaction(
218
+ summary=result.summary,
219
+ first_kept_entry_id=result.first_kept_entry_id,
220
+ tokens_before=result.tokens_before,
221
+ )
222
+ self._compaction_failures = 0
223
+ await self.hooks.emit(CompactionEndEvent(
224
+ manual=True,
225
+ tokens_before=result.tokens_before,
226
+ summary_length=len(result.summary),
227
+ from_extension=from_extension,
228
+ ))
229
+ return True
230
+
231
+ async def _check_compaction(self) -> None:
232
+ """Auto-compact if context usage exceeds the threshold. Circuit-breaks after 3 failures."""
233
+ from tau.session.compaction import (
234
+ estimate_context_tokens, should_compact,
235
+ prepare_compaction,
236
+ )
237
+ from tau.hooks.engine import CompactionEndEvent
238
+
239
+ if self._compaction_failures >= 3:
240
+ return
241
+
242
+ settings = self._config.compaction
243
+ if not settings.enabled:
244
+ return
245
+
246
+ entries = self._session_manager.get_branch()
247
+ session_ctx = self._session_manager.build_session_context()
248
+ llm_messages = _to_llm_messages(session_ctx.messages)
249
+
250
+ usage = estimate_context_tokens(llm_messages)
251
+ if not should_compact(usage.tokens, self._context_window, settings):
252
+ return
253
+
254
+ preparation = prepare_compaction(entries, settings)
255
+ if preparation is None:
256
+ return
257
+
258
+ try:
259
+ result, from_extension = await self._run_compaction(preparation, entries, manual=False)
260
+ self._session_manager.append_compaction(
261
+ summary=result.summary,
262
+ first_kept_entry_id=result.first_kept_entry_id,
263
+ tokens_before=result.tokens_before,
264
+ )
265
+ self._compaction_failures = 0
266
+ await self.hooks.emit(CompactionEndEvent(
267
+ manual=False,
268
+ tokens_before=result.tokens_before,
269
+ summary_length=len(result.summary),
270
+ from_extension=from_extension,
271
+ ))
272
+ except Exception:
273
+ self._compaction_failures += 1
274
+
275
+ async def _run_compaction(self, preparation: "CompactionPreparation", entries: list, manual: bool, custom_instructions: str | None = None) -> tuple:
276
+ """Emit before_compaction (allowing interception), then run the default algorithm.
277
+
278
+ Returns (CompactionResult, from_extension: bool).
279
+ Extensions may cancel (raises RuntimeError) or supply a custom CompactionResult.
280
+ Exceptions in before_compaction handlers are swallowed — first non-error result wins,
281
+ consistent with error-fallthrough behaviour.
282
+ """
283
+ from tau.session.compaction import compact as _compact
284
+ from tau.hooks.types import BeforeCompactionEvent, BeforeCompactionResult, CompactionStartEvent
285
+
286
+ before_results = await self.hooks.emit(BeforeCompactionEvent(
287
+ preparation=preparation,
288
+ entries=entries,
289
+ manual=manual,
290
+ ))
291
+
292
+ for res in before_results:
293
+ if not isinstance(res, BeforeCompactionResult):
294
+ continue
295
+ if res.cancel:
296
+ raise RuntimeError("Compaction cancelled by extension")
297
+ if res.compaction is not None:
298
+ return res.compaction, True
299
+
300
+ await self.hooks.emit(CompactionStartEvent(manual=manual))
301
+ result = await _compact(preparation, self._engine.llm, custom_instructions=custom_instructions) # type: ignore[arg-type]
302
+ return result, False
303
+
304
+ # -------------------------------------------------------------------------
305
+ # Core turn entry point
306
+ # -------------------------------------------------------------------------
307
+
308
+ async def invoke(self, text: str, options: PromptOptions | None = None) -> None:
309
+ """Run one user turn."""
310
+ if self._phase != AgentPhase.IDLE:
311
+ raise RuntimeError(f"Agent is busy (phase={self._phase!r}). Wait for the current operation to finish.")
312
+
313
+ opts = options or PromptOptions()
314
+
315
+ session_ctx = self._session_manager.build_session_context()
316
+ llm_messages = _to_llm_messages(session_ctx.messages)
317
+ llm_messages = strip_unusable_trailing_assistant(llm_messages, self._session_manager)
318
+
319
+ if opts.images:
320
+ user_message = UserMessage.with_images(text, list(opts.images))
321
+ elif opts.audio:
322
+ user_message = UserMessage.with_audio(text, list(opts.audio))
323
+ elif opts.video:
324
+ user_message = UserMessage.with_video(text, list(opts.video))
325
+ else:
326
+ user_message = UserMessage.from_text(text)
327
+ self._session_manager.append_message(user_message, meta=opts.meta)
328
+ llm_messages.append(user_message)
329
+
330
+ ctx = AgentContext(
331
+ system_prompt=self._system_prompt,
332
+ messages=llm_messages,
333
+ tools=self._engine.tools,
334
+ )
335
+
336
+ self._signal = asyncio.Event()
337
+ self._engine.llm.api.options.signal = self._signal
338
+
339
+ self._phase = AgentPhase.TURN
340
+ try:
341
+ await self._run(ctx)
342
+ finally:
343
+ self._phase = AgentPhase.IDLE
344
+
345
+ await self.hooks.emit(SavePointEvent())
346
+
347
+ await self._check_compaction()
348
+
349
+ if not self._engine.has_pending_messages():
350
+ await self.hooks.emit(SettledEvent())
351
+
352
+ async def _run(self, ctx: AgentContext) -> None:
353
+ unsubscribe = self.hooks.register(
354
+ 'message_end',
355
+ lambda event: self._on_message_end(event),
356
+ )
357
+ unsubscribe_rollback = self.hooks.register(
358
+ 'message_rollback',
359
+ lambda event: self._on_message_rollback(event),
360
+ )
361
+ try:
362
+ await self._engine.run(ctx, signal=self._signal)
363
+ finally:
364
+ unsubscribe()
365
+ unsubscribe_rollback()
366
+
367
+ error = self._engine.state.error_message
368
+ if error is not None:
369
+ raise RuntimeError(f"Agent failed: {error}.")
tau/agent/types.py ADDED
@@ -0,0 +1,61 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from pydantic import BaseModel
9
+ from tau.message.types import LLMMessage
10
+ from tau.session.types import MessageMeta
11
+ from tau.session.compaction import CompactionSettings, DEFAULT_COMPACTION_SETTINGS
12
+
13
+ if TYPE_CHECKING:
14
+ from tau.tool.types import Tool
15
+
16
+
17
class AgentPhase(str, Enum):
    """Lifecycle state of an agent turn.

    ``IDLE`` means no turn is running; ``TURN`` means ``invoke`` is executing one.
    """

    IDLE = "idle"  # ready to accept a new turn
    TURN = "turn"  # a user turn is in progress
21
+
22
+
23
@dataclass
class AgentContext:
    """Everything handed to the LLM for a single turn.

    Holds the system prompt, the LLM-ready conversation history and the tools
    exposed to the model.
    """

    system_prompt: str  # system prompt text for the turn
    messages: list[LLMMessage]  # conversation history sent to the model
    tools: list[Tool] = field(default_factory=list)  # fresh empty list per instance
29
+
30
+
31
class AgentConfig(BaseModel):
    """Runtime settings consumed by ``Agent.__init__``.

    ``arbitrary_types_allowed`` lets pydantic accept field types it has no
    schema for (they are checked with ``isinstance``).
    """

    model_config = dict(arbitrary_types_allowed=True)

    cwd: Path  # working directory, exposed as Agent.cwd
    system_prompt: str = ""
    model: Any | None = None  # opaque model handle; type deliberately left open
    context_window: int = 200_000  # tokens; used for usage % and auto-compaction checks
    compaction: CompactionSettings = DEFAULT_COMPACTION_SETTINGS
40
+
41
+
42
+
43
+
44
class PromptOptions(BaseModel):
    """Per-call options for ``Agent.invoke``.

    Carries optional session metadata and raw media payloads. ``Agent.invoke``
    attaches only the first non-empty media kind, checked in the order
    images, audio, video.
    """

    model_config = dict(arbitrary_types_allowed=True)

    meta: MessageMeta | None = None  # stored with the user message in the session
    # Raw media payloads. pydantic copies non-hashable defaults per instance,
    # so these empty lists are not shared between instances.
    images: list[bytes] = []
    audio: list[bytes] = []
    video: list[bytes] = []
52
+
53
+
54
@dataclass
class ContextUsage:
    """Snapshot of how much of the context window is in use.

    ``percent`` is ``tokens / context_window * 100`` when a window size is
    known, otherwise ``None``.
    """

    tokens: int  # tokens currently in context
    context_window: int  # window size in tokens
    percent: float | None = None
60
+
61
+
tau/auth/manager.py ADDED
@@ -0,0 +1,247 @@
1
+ from __future__ import annotations
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ from typing import List
6
+ from tau.inference.provider.registry import ProviderRegistry
7
+ from tau.inference.provider.oauth import OAuthLoginCallbacks
8
+ from tau.settings.paths import get_auth_path
9
+ from tau.auth.types import AuthCredential, AuthStatus, OAuthCredential, APICredential, AuthType, LockResult
10
+ from tau.auth.storage import AuthStorage, FileAuthStorage, InMemoryAuthStorage
11
+ from tau.utils.secrets import resolve_secret
12
+
13
+
14
+ def _get_env_api_key(provider: str) -> str | None:
15
+ """Get API key for a provider from environment variables."""
16
+ return os.environ.get(f"{provider.upper()}_API_KEY")
17
+
18
+
19
+ class AuthManager:
20
+ """Credential storage with pluggable backends."""
21
+
22
+ def __init__(self, registry: ProviderRegistry, storage: AuthStorage):
23
+ self.registry = registry
24
+ self.storage = storage
25
+ self.runtime_overrides: dict[str, str] = {}
26
+ self._load_error: Exception | None = None
27
+ self._errors: list[Exception] = []
28
+ self.data: dict[str, AuthCredential] = self._load()
29
+
30
+ @staticmethod
31
+ def create(registry: ProviderRegistry, auth_path: Path | None = None) -> "AuthManager":
32
+ """Create AuthManager with file storage."""
33
+ path = auth_path or get_auth_path()
34
+ storage = FileAuthStorage(path)
35
+ return AuthManager(registry, storage)
36
+
37
+ @staticmethod
38
+ def from_storage(registry: ProviderRegistry, storage: AuthStorage) -> "AuthManager":
39
+ """Create AuthManager with custom storage."""
40
+ return AuthManager(registry, storage)
41
+
42
+ @staticmethod
43
+ def in_memory(registry: ProviderRegistry, initial: dict = {}) -> "AuthManager":
44
+ """Create AuthManager with in-memory storage for testing."""
45
+ storage = InMemoryAuthStorage()
46
+ storage.with_lock(lambda _: LockResult(result=None, next=json.dumps(initial, indent=2)))
47
+ return AuthManager.from_storage(registry, storage)
48
+
49
+ def _record_error(self, error: Exception) -> None:
50
+ """Record an error for later retrieval."""
51
+ self._errors.append(error)
52
+
53
+ def _parse_storage_data(self, content: str | None) -> dict[str, AuthCredential]:
54
+ """Parse credential data from storage JSON."""
55
+ if not content:
56
+ return {}
57
+ raw_data = json.loads(content)
58
+ data: dict[str, AuthCredential] = {}
59
+ for k, v in raw_data.items():
60
+ cred_type = v.get("type")
61
+ match cred_type:
62
+ case AuthType.OAuth:
63
+ raw_extra = v.get("extra") or {}
64
+ extra = {str(ek): str(ev) for ek, ev in raw_extra.items()} if isinstance(raw_extra, dict) else {}
65
+ data[k] = OAuthCredential(
66
+ access=v.get("access", ""),
67
+ refresh=v.get("refresh", ""),
68
+ expires=v.get("expires", 0),
69
+ extra=extra,
70
+ )
71
+ case AuthType.ApiKey:
72
+ data[k] = APICredential(key=v.get("key", ""))
73
+ return data
74
+
75
+ def _load(self) -> dict[str, AuthCredential]:
76
+ """Load credentials from storage."""
77
+ try:
78
+ result = self.storage.with_lock(lambda current: LockResult(result=current))
79
+ self._load_error = None
80
+ return self._parse_storage_data(result.result)
81
+ except Exception as e:
82
+ self._load_error = e
83
+ self._record_error(e)
84
+ return {}
85
+
86
+ @staticmethod
87
+ def _serialize_credential(credential: AuthCredential) -> dict:
88
+ """Serialize a credential to storable dict format."""
89
+ if isinstance(credential, OAuthCredential):
90
+ return {
91
+ "type": AuthType.OAuth,
92
+ "access": credential.access,
93
+ "refresh": credential.refresh,
94
+ "expires": credential.expires,
95
+ "extra": dict(credential.extra),
96
+ }
97
+ return {"type": AuthType.ApiKey, "key": credential.key}
98
+
99
+ def _persist_provider_change(self, provider: str, credential: AuthCredential | None) -> None:
100
+ """Persist a credential change to storage."""
101
+ if self._load_error:
102
+ return
103
+
104
+ def update_fn(current: str | None) -> LockResult:
105
+ """Update storage data with new credential."""
106
+ current_data = self._parse_storage_data(current)
107
+ merged = {k: self._serialize_credential(v) for k, v in current_data.items()}
108
+ if credential:
109
+ merged[provider] = self._serialize_credential(credential)
110
+ else:
111
+ merged.pop(provider, None)
112
+ return LockResult(result=None, next=json.dumps(merged, indent=2))
113
+
114
+ try:
115
+ self.storage.with_lock(update_fn)
116
+ except Exception as e:
117
+ self._record_error(e)
118
+
119
+ def reload(self) -> None:
120
+ """Reload credentials from storage."""
121
+ self.data = self._load()
122
+
123
+ def get(self, provider: str) -> AuthCredential | None:
124
+ """Return the stored credential for a provider, or None if not found."""
125
+ return self.data.get(provider)
126
+
127
+ def has(self, provider: str) -> bool:
128
+ """Check if credentials exist for a provider in storage."""
129
+ return provider in self.data
130
+
131
+ def list(self) -> list[str]:
132
+ """List all providers with stored credentials."""
133
+ return list(self.data.keys())
134
+
135
+ def set(self, provider: str, credential: AuthCredential) -> None:
136
+ """Store a credential for a provider and persist to storage."""
137
+ self.data[provider] = credential
138
+ self._persist_provider_change(provider=provider, credential=credential)
139
+
140
+ def remove(self, provider: str) -> None:
141
+ """Remove the stored credential for a provider and persist to storage."""
142
+ self.data.pop(provider, None)
143
+ self._persist_provider_change(provider=provider, credential=None)
144
+
145
+ def set_runtime_api_key(self, provider: str, api_key: str) -> None:
146
+ """Set a runtime API key override (not persisted)."""
147
+ self.runtime_overrides[provider] = api_key
148
+
149
+ def remove_runtime_api_key(self, provider: str) -> None:
150
+ """Remove a runtime API key override."""
151
+ self.runtime_overrides.pop(provider, None)
152
+
153
+ def get_auth_status(self, provider: str) -> AuthStatus:
154
+ """Return auth status without exposing credential values."""
155
+ if self.has(provider):
156
+ return AuthStatus(configured=True, source="stored")
157
+ if provider in self.runtime_overrides:
158
+ return AuthStatus(configured=True, source="runtime", label="--api-key")
159
+ env_key = f"{provider.upper()}_API_KEY"
160
+ if os.environ.get(env_key):
161
+ return AuthStatus(configured=True, source="env", label=env_key)
162
+ return AuthStatus(configured=False)
163
+
164
+ def drain_errors(self) -> List[Exception]:
165
+ """Return and clear accumulated errors."""
166
+ drained = list(self._errors)
167
+ self._errors.clear()
168
+ return drained
169
+
170
+ async def get_api_key(self, provider: str) -> str | None:
171
+ """Get an API key for a provider, refreshing OAuth tokens if needed."""
172
+ # 1. Runtime override
173
+ if provider in self.runtime_overrides:
174
+ return resolve_secret(self.runtime_overrides[provider])
175
+
176
+ credential = self.get(provider)
177
+
178
+ match credential:
179
+ case APICredential():
180
+ # The stored key may be a literal, "$ENV_VAR", or "!command";
181
+ # resolved once and cached (see tau.utils.secrets).
182
+ return resolve_secret(credential.key)
183
+ case OAuthCredential():
184
+ oauth_provider = self.registry.text.get_oauth_provider(provider=provider)
185
+ if not oauth_provider:
186
+ return None
187
+
188
+ if oauth_provider.is_expired(credential=credential):
189
+ refreshed_credential = await self._refresh_oauth_token_with_lock(provider=provider)
190
+ if refreshed_credential:
191
+ credential = refreshed_credential
192
+ else:
193
+ return None
194
+ return oauth_provider.get_api_key(credential=credential)
195
+
196
+ # 2. Environment variable fallback
197
+ return _get_env_api_key(provider)
198
+
199
+ async def _refresh_oauth_token_with_lock(self, provider: str) -> OAuthCredential | None:
200
+ """Refresh an expired OAuth token with file locking to prevent race conditions."""
201
+ oauth_provider = self.registry.text.get_oauth_provider(provider=provider)
202
+ if not oauth_provider:
203
+ return None
204
+
205
+ async def refresh_fn(current: str | None) -> LockResult:
206
+ """Refresh OAuth token in storage."""
207
+ current_data = self._parse_storage_data(current)
208
+ credential = current_data.get(provider)
209
+
210
+ if not isinstance(credential, OAuthCredential):
211
+ return LockResult(result=None)
212
+
213
+ # Check if another instance already refreshed
214
+ if not oauth_provider.is_expired(credential=credential):
215
+ return LockResult(result=credential)
216
+
217
+ try:
218
+ refreshed_credential = await oauth_provider.refresh_token(credential=credential)
219
+ if credential.extra:
220
+ merged_extra = dict(credential.extra)
221
+ merged_extra.update(refreshed_credential.extra)
222
+ refreshed_credential.extra = merged_extra
223
+ current_data[provider] = refreshed_credential
224
+ self.data = current_data
225
+ serialized = {k: self._serialize_credential(v) for k, v in current_data.items()}
226
+ return LockResult(result=refreshed_credential, next=json.dumps(serialized, indent=2))
227
+ except Exception as e:
228
+ self._record_error(e)
229
+ return LockResult(result=None)
230
+
231
+ result = await self.storage.with_lock_async(refresh_fn)
232
+ return result.result
233
+
234
+ async def login(self, provider: str, callbacks: OAuthLoginCallbacks):
235
+ """Perform OAuth login flow for a provider and store the resulting credential."""
236
+ if oauth_provider := self.registry.text.get_oauth_provider(provider):
237
+ credential = await oauth_provider.login(callbacks=callbacks)
238
+ self.data[provider] = credential
239
+ self._persist_provider_change(provider, credential)
240
+
241
+ async def logout(self, provider: str):
242
+ """Perform OAuth logout for a provider and remove the stored credential."""
243
+ if oauth_provider := self.registry.text.get_oauth_provider(provider):
244
+ if credential := self.get(provider):
245
+ if isinstance(credential, OAuthCredential):
246
+ await oauth_provider.logout(credential=credential)
247
+ self.remove(provider)