flowent 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -0
- package/backend/README.md +3 -0
- package/backend/pyproject.toml +1 -1
- package/backend/src/flowent/agent.py +61 -2
- package/backend/src/flowent/agent_runtime.py +220 -0
- package/backend/src/flowent/api_models.py +9 -1
- package/backend/src/flowent/app.py +9 -2
- package/backend/src/flowent/permissions.py +12 -3
- package/backend/src/flowent/routes/workflow_routes.py +22 -42
- package/backend/src/flowent/sandbox.py +63 -19
- package/backend/src/flowent/state/models.py +1 -1
- package/backend/src/flowent/static/assets/index-ByGH1ZWH.css +2 -0
- package/backend/src/flowent/static/assets/index-D3WSbctU.js +98 -0
- package/backend/src/flowent/static/index.html +2 -2
- package/backend/src/flowent/tools.py +60 -4
- package/backend/src/flowent/workflow_service.py +115 -0
- package/backend/src/flowent/workflow_tools.py +279 -0
- package/backend/src/flowent/workflows.py +182 -19
- package/backend/src/flowent/workspace/runtime.py +81 -93
- package/backend/uv.lock +1 -1
- package/dist/frontend/assets/index-ByGH1ZWH.css +2 -0
- package/dist/frontend/assets/index-D3WSbctU.js +98 -0
- package/dist/frontend/index.html +2 -2
- package/package.json +1 -1
- package/backend/src/flowent/static/assets/index-BX18a4Jz.js +0 -100
- package/backend/src/flowent/static/assets/index-EC37agAH.css +0 -2
- package/dist/frontend/assets/index-BX18a4Jz.js +0 -100
- package/dist/frontend/assets/index-EC37agAH.css +0 -2
|
@@ -1,16 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import json
|
|
2
4
|
import re
|
|
5
|
+
import sys
|
|
6
|
+
import tempfile
|
|
3
7
|
from collections import defaultdict, deque
|
|
4
8
|
from collections.abc import Mapping
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
5
11
|
|
|
6
12
|
from pydantic import BaseModel, ConfigDict, Field
|
|
7
13
|
|
|
8
|
-
from flowent.
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
ProviderConnection,
|
|
12
|
-
complete_chat,
|
|
13
|
-
)
|
|
14
|
+
from flowent.context import runtime_context_messages
|
|
15
|
+
from flowent.llm import ProviderConnection
|
|
16
|
+
from flowent.sandbox import SandboxRunner
|
|
14
17
|
from flowent.storage import (
|
|
15
18
|
StoredWorkflow,
|
|
16
19
|
StoredWorkflowDefinition,
|
|
@@ -18,6 +21,9 @@ from flowent.storage import (
|
|
|
18
21
|
StoredWorkflowNode,
|
|
19
22
|
)
|
|
20
23
|
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from flowent.agent_runtime import FlowentAgentRuntime
|
|
26
|
+
|
|
21
27
|
|
|
22
28
|
class WorkflowNodeRunResult(BaseModel):
|
|
23
29
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -37,7 +43,39 @@ class WorkflowRunResponse(BaseModel):
|
|
|
37
43
|
workflow_id: str
|
|
38
44
|
|
|
39
45
|
|
|
46
|
+
class WorkflowRunRequestValues(BaseModel):
    # Input values supplied by the caller for one workflow run.
    # Unknown keys are rejected by the model config below.
    model_config = ConfigDict(extra="forbid")

    # Fallback text for an input node that has no entry in ``input_values``.
    default_input: str = Field(default="")
    # Explicit per-node values, keyed by input node id.
    input_values: dict[str, str] = Field(default_factory=dict)
|
|
51
|
+
|
|
52
|
+
|
|
40
53
|
# Matches template placeholders of the form ``{{ name.output }}`` (spaces
# inside the braces are optional); group 1 captures ``name``, which may
# contain letters, digits, ``_``, ``.`` and ``-``.
PLACEHOLDER_PATTERN = re.compile(r"\{\{\s*([A-Za-z0-9_.-]+)\.output\s*\}\}")
# Driver script for workflow code nodes. run_code_node executes it as
# ``python -I -c PYTHON_CODE_RUNNER`` and feeds a JSON object on stdin with
# the keys ``code``, ``input`` and ``inputs``.
#
# The script exec()s ``code`` with the globals ``input``, ``inputs`` and
# ``output`` (initially "") while capturing print() output. It then writes a
# single result to stdout:
#   * ``output`` as left by the code (None is treated as "");
#   * if that is still "", the captured print() text with trailing
#     newlines stripped;
#   * any non-string result is JSON-encoded.
# Exceptions raised by the user code are not caught, so the process exits
# non-zero with the traceback on stderr.
#
# NOTE(review): exec() here is unrestricted. Any isolation comes from how
# run_code_node launches the process (SandboxRunner, ``-I``, temp cwd).
PYTHON_CODE_RUNNER = r"""
import contextlib
import io
import json
import sys

payload = json.loads(sys.stdin.read() or "{}")
namespace = {
    "input": payload.get("input", ""),
    "inputs": payload.get("inputs", []),
    "output": "",
}
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
    exec(str(payload.get("code", "")), namespace)
captured = stdout.getvalue()
result = namespace.get("output")
if result is None:
    result = ""
if result == "" and captured:
    result = captured.rstrip("\n")
if not isinstance(result, str):
    result = json.dumps(result, ensure_ascii=False)
print(result, end="")
"""
|
|
41
79
|
|
|
42
80
|
|
|
43
81
|
def validate_workflow(workflow: StoredWorkflow) -> StoredWorkflow:
|
|
@@ -77,8 +115,8 @@ def validate_workflow_definition(definition: StoredWorkflowDefinition) -> list[s
|
|
|
77
115
|
raise ValueError("Workflow node ids must not be empty.")
|
|
78
116
|
if len(set(node_ids)) != len(node_ids):
|
|
79
117
|
raise ValueError("Workflow node ids must be unique.")
|
|
80
|
-
if not any(node.type
|
|
81
|
-
raise ValueError("Workflow needs an input node.")
|
|
118
|
+
if not any(node.type in {"input", "timer"} for node in definition.nodes):
|
|
119
|
+
raise ValueError("Workflow needs an input or timer node.")
|
|
82
120
|
if not any(node.type == "output" for node in definition.nodes):
|
|
83
121
|
raise ValueError("Workflow needs an output node.")
|
|
84
122
|
|
|
@@ -100,6 +138,43 @@ def workflow_requires_connection(definition: StoredWorkflowDefinition) -> bool:
|
|
|
100
138
|
return any(node.type == "agent" for node in definition.nodes)
|
|
101
139
|
|
|
102
140
|
|
|
141
|
+
def timer_run_node_ids(
    definition: StoredWorkflowDefinition, timer_node_id: str
) -> set[str]:
    """Return the ids of the nodes that must run when ``timer_node_id`` fires.

    The result contains the timer itself, every node reachable downstream of
    it, and every upstream dependency of those nodes. The backward walk
    never adds a different timer node, so other timers (and anything only
    reachable upstream through them) are not triggered.

    Edges whose source or target is not a node of ``definition`` are
    ignored, so the result only ever contains real node ids.

    Args:
        definition: The workflow graph.
        timer_node_id: Id of the timer node that fired.

    Returns:
        The set of node ids to execute for this timer run.

    Raises:
        ValueError: if ``timer_node_id`` is not a timer node of ``definition``.
    """
    nodes = {node.id: node for node in definition.nodes}
    timer_node = nodes.get(timer_node_id)
    if timer_node is None or timer_node.type != "timer":
        raise ValueError("Timer node not found.")

    outgoing: dict[str, list[str]] = defaultdict(list)
    incoming: dict[str, list[str]] = defaultdict(list)
    for edge in definition.edges:
        # Skip dangling edges: otherwise unknown ids would leak into the
        # result, and the backward pass would fail with a bare KeyError.
        if edge.source not in nodes or edge.target not in nodes:
            continue
        outgoing[edge.source].append(edge.target)
        incoming[edge.target].append(edge.source)

    # Forward pass: everything downstream of the timer.
    active = {timer_node_id}
    queue = deque([timer_node_id])
    while queue:
        node_id = queue.popleft()
        for target in outgoing[node_id]:
            if target not in active:
                active.add(target)
                queue.append(target)

    # Backward pass: pull in dependencies of active nodes, but never cross
    # into another timer (the firing timer is already in ``active``).
    queue = deque(active)
    while queue:
        node_id = queue.popleft()
        for source in incoming[node_id]:
            if source in active:
                continue
            if nodes[source].type == "timer":
                continue
            active.add(source)
            queue.append(source)

    return active
|
|
176
|
+
|
|
177
|
+
|
|
103
178
|
def topological_node_ids(definition: StoredWorkflowDefinition) -> list[str]:
|
|
104
179
|
node_ids = [node.id for node in definition.nodes]
|
|
105
180
|
outgoing: dict[str, list[str]] = defaultdict(list)
|
|
@@ -133,35 +208,76 @@ def topological_node_ids(definition: StoredWorkflowDefinition) -> list[str]:
|
|
|
133
208
|
|
|
134
209
|
async def run_workflow_definition(
    *,
    connection: ProviderConnection | None,
    definition: StoredWorkflowDefinition,
    default_input: str = "",
    input_values: Mapping[str, str] | None = None,
    runtime: FlowentAgentRuntime | None = None,
    timer_node_id: str = "",
    workflow_depth: int = 0,
    workflow_id: str,
) -> WorkflowRunResponse:
    """Validate ``definition`` and execute it once.

    Args:
        connection: Provider connection for agent nodes. Required whenever
            the definition contains at least one agent node.
        definition: The workflow graph to run.
        default_input: Fallback text for input nodes without an explicit value.
        input_values: Explicit values keyed by input node id; copied into a
            fresh dict before the run.
        runtime: Agent runtime used to execute agent nodes.
        timer_node_id: When non-empty, only the part of the graph driven by
            this timer node is executed.
        workflow_depth: Nesting depth forwarded to agent calls.
        workflow_id: Identifier forwarded to ``run_workflow_once``.

    Returns:
        The response produced by ``run_workflow_once``.

    Raises:
        ValueError: if the definition fails validation, or if it needs a
            connection and none was given.
    """
    execution_order = validate_workflow_definition(definition)
    needs_connection = workflow_requires_connection(definition)
    if needs_connection and connection is None:
        raise ValueError("Choose a provider and model before running.")

    request_values = WorkflowRunRequestValues(
        default_input=default_input,
        input_values={} if input_values is None else dict(input_values),
    )
    return await run_workflow_once(
        workflow_id=workflow_id,
        definition=definition,
        ordered_ids=execution_order,
        input_values=request_values,
        connection=connection,
        runtime=runtime,
        timer_node_id=timer_node_id,
        workflow_depth=workflow_depth,
    )
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
async def run_workflow_once(
|
|
240
|
+
*,
|
|
241
|
+
connection: ProviderConnection | None,
|
|
242
|
+
definition: StoredWorkflowDefinition,
|
|
243
|
+
input_values: WorkflowRunRequestValues,
|
|
244
|
+
ordered_ids: list[str],
|
|
245
|
+
runtime: FlowentAgentRuntime | None = None,
|
|
246
|
+
timer_node_id: str = "",
|
|
247
|
+
workflow_depth: int = 0,
|
|
248
|
+
workflow_id: str,
|
|
249
|
+
) -> WorkflowRunResponse:
|
|
145
250
|
nodes = {node.id: node for node in definition.nodes}
|
|
146
251
|
incoming_edges = edges_by_target(definition.edges)
|
|
252
|
+
active_node_ids = (
|
|
253
|
+
timer_run_node_ids(definition, timer_node_id) if timer_node_id else None
|
|
254
|
+
)
|
|
147
255
|
results: dict[str, WorkflowNodeRunResult] = {
|
|
148
256
|
node.id: WorkflowNodeRunResult(id=node.id, status="pending")
|
|
149
257
|
for node in definition.nodes
|
|
150
258
|
}
|
|
151
259
|
outputs: dict[str, str] = {}
|
|
152
260
|
named_outputs: dict[str, str] = {}
|
|
261
|
+
remaining_default_input = input_values.default_input
|
|
153
262
|
|
|
154
263
|
for node_id in ordered_ids:
|
|
155
264
|
node = nodes[node_id]
|
|
265
|
+
if active_node_ids is not None and node.id not in active_node_ids:
|
|
266
|
+
continue
|
|
156
267
|
results[node.id] = WorkflowNodeRunResult(id=node.id, status="running")
|
|
157
268
|
try:
|
|
158
269
|
output = await run_node(
|
|
159
|
-
completion=completion,
|
|
160
270
|
connection=connection,
|
|
271
|
+
default_input=remaining_default_input,
|
|
272
|
+
input_values=input_values.input_values,
|
|
161
273
|
incoming_edges=incoming_edges[node.id],
|
|
162
274
|
node=node,
|
|
163
275
|
outputs=outputs,
|
|
276
|
+
runtime=runtime,
|
|
277
|
+
workflow_depth=workflow_depth,
|
|
164
278
|
)
|
|
279
|
+
if node.type == "input" and remaining_default_input:
|
|
280
|
+
remaining_default_input = ""
|
|
165
281
|
except Exception as error:
|
|
166
282
|
results[node.id] = WorkflowNodeRunResult(
|
|
167
283
|
error=str(error) or "Node could not be completed.",
|
|
@@ -193,38 +309,79 @@ async def run_workflow_definition(
|
|
|
193
309
|
|
|
194
310
|
async def run_node(
    *,
    connection: ProviderConnection | None,
    default_input: str,
    input_values: Mapping[str, str],
    incoming_edges: list[StoredWorkflowEdge],
    node: StoredWorkflowNode,
    outputs: Mapping[str, str],
    runtime: FlowentAgentRuntime | None = None,
    workflow_depth: int = 0,
) -> str:
    """Execute a single workflow node and return its text output.

    Behaviour by ``node.type``:

    * ``input``: the value keyed by ``node.id`` in ``input_values``; else
      ``default_input`` when non-empty; else the node's ``default_value``.
    * ``agent``: renders the node's ``prompt`` (falling back to the joined
      upstream outputs) against ``outputs`` and sends it to
      ``runtime.complete`` together with the runtime context messages.
    * ``merge``: passes the upstream outputs to ``merge_json_outputs`` when
      ``merge_strategy`` is ``"json"``; otherwise joins the non-empty ones
      with newlines.
    * ``code``: runs the node's snippet through ``run_code_node``.
    * ``timer``: returns ``timer_payload(node)``.
    * ``output``: joins the upstream outputs.

    Raises:
        ValueError: for an agent node without a connection or runtime, or
            for an unsupported node type.
    """
    node_type = node.type

    if node_type == "input":
        if node.id in input_values:
            return input_values[node.id]
        return default_input or node_data_text(node, "default_value")

    if node_type == "agent":
        if connection is None:
            raise ValueError("Choose a provider and model before running.")
        if runtime is None:
            raise ValueError("Agent runtime is not available.")
        template = node_data_text(node, "prompt") or joined_upstream_outputs(
            incoming_edges, outputs
        )
        prompt = render_template(template, outputs)
        agent_prompt = runtime.store.read_state().settings.agent_prompt
        conversation = [
            *runtime_context_messages(runtime.cwd, agent_prompt),
            {"role": "user", "content": prompt},
        ]
        reply = await runtime.complete(
            connection=connection,
            messages=conversation,
            user_request=prompt,
            workflow_depth=workflow_depth,
        )
        return reply.content

    if node_type == "merge":
        upstream = upstream_outputs(incoming_edges, outputs)
        if node_data_text(node, "merge_strategy") == "json":
            return merge_json_outputs(upstream)
        return "\n".join(text for text in upstream if text)

    if node_type == "code":
        upstream = upstream_outputs(incoming_edges, outputs)
        return await run_code_node(node, upstream)

    if node_type == "timer":
        return timer_payload(node)

    if node_type == "output":
        return joined_upstream_outputs(incoming_edges, outputs)

    raise ValueError("Node type is not supported.")
|
|
226
361
|
|
|
227
362
|
|
|
363
|
+
async def run_code_node(
    node: StoredWorkflowNode,
    upstream: list[str],
    *,
    timeout_seconds: int = 10,
) -> str:
    """Run a code node's Python snippet in a sandboxed subprocess.

    The snippet is executed by ``PYTHON_CODE_RUNNER`` under ``python -I``,
    with a fresh temporary directory as the working directory. Inside the
    snippet, ``input`` is the non-empty upstream outputs joined with
    newlines, ``inputs`` is the raw ``upstream`` list, and the snippet may
    assign ``output``.

    Args:
        node: The code node; its ``code`` data field holds the snippet.
        upstream: Outputs of the node's upstream nodes.
        timeout_seconds: Time limit given to the sandbox for this run. One
            value feeds both the runner and the call so they cannot drift.

    Returns:
        The runner's stdout. A blank snippet passes the joined upstream text
        straight through without starting a process.

    Raises:
        ValueError: if the subprocess exits non-zero. The message is its
            stderr (or stdout), or "Code failed." when both are empty.
    """
    code = node_data_text(node, "code")
    joined_input = joined_text(upstream)
    if not code.strip():
        return joined_input

    payload = json.dumps(
        {
            "code": code,
            "input": joined_input,
            "inputs": upstream,
        },
        ensure_ascii=False,
    )
    with tempfile.TemporaryDirectory(prefix="flowent-workflow-code-") as code_dir:
        runner = SandboxRunner(timeout_seconds=timeout_seconds, cwd=Path(code_dir))
        result = await runner.run_async(
            [sys.executable, "-I", "-c", PYTHON_CODE_RUNNER],
            input_text=payload,
            timeout_seconds=timeout_seconds,
        )

    if result.exit_code != 0:
        raise ValueError((result.stderr or result.stdout).strip() or "Code failed.")
    return result.stdout
|
|
383
|
+
|
|
384
|
+
|
|
228
385
|
def edges_by_target(
|
|
229
386
|
edges: list[StoredWorkflowEdge],
|
|
230
387
|
) -> dict[str, list[StoredWorkflowEdge]]:
|
|
@@ -247,6 +404,10 @@ def node_output_key(node: StoredWorkflowNode) -> str:
|
|
|
247
404
|
return node_data_text(node, "output_key") or node.id
|
|
248
405
|
|
|
249
406
|
|
|
407
|
+
def timer_payload(node: StoredWorkflowNode) -> str:
    """Return the timer node's ``payload`` text, or "Timer fired." when it is empty."""
    configured = node_data_text(node, "payload")
    if configured:
        return configured
    return "Timer fired."
|
|
409
|
+
|
|
410
|
+
|
|
250
411
|
def upstream_outputs(
|
|
251
412
|
incoming_edges: list[StoredWorkflowEdge],
|
|
252
413
|
outputs: Mapping[str, str],
|
|
@@ -258,9 +419,11 @@ def joined_upstream_outputs(
|
|
|
258
419
|
incoming_edges: list[StoredWorkflowEdge],
|
|
259
420
|
outputs: Mapping[str, str],
|
|
260
421
|
) -> str:
|
|
261
|
-
return
|
|
262
|
-
|
|
263
|
-
|
|
422
|
+
return joined_text(upstream_outputs(incoming_edges, outputs))
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def joined_text(values: list[str]) -> str:
    """Join the non-empty strings in *values* with newlines, skipping empty ones."""
    # filter(None, ...) keeps only truthy items, i.e. drops "" entries.
    return "\n".join(filter(None, values))
|
|
264
427
|
|
|
265
428
|
|
|
266
429
|
def render_template(template: str, outputs: Mapping[str, str]) -> str:
|
|
@@ -10,14 +10,13 @@ from uuid import uuid4
|
|
|
10
10
|
|
|
11
11
|
from fastapi import HTTPException
|
|
12
12
|
|
|
13
|
-
from flowent.agent import AgentContextUpdate
|
|
14
|
-
from flowent.
|
|
13
|
+
from flowent.agent import AgentContextUpdate
|
|
14
|
+
from flowent.agent_runtime import FlowentAgentRuntime
|
|
15
15
|
from flowent.compact import CompactInput, CompactProvider
|
|
16
16
|
from flowent.context import runtime_context_messages
|
|
17
17
|
from flowent.llm import ChatMessage, CompletionCallable, ProviderConnection
|
|
18
18
|
from flowent.logging import TRACE_LEVEL
|
|
19
19
|
from flowent.mcp import McpManager
|
|
20
|
-
from flowent.permissions import run_tool_with_path_permissions
|
|
21
20
|
from flowent.provider_connections import selected_connection
|
|
22
21
|
from flowent.skills import explicit_skill_messages
|
|
23
22
|
from flowent.storage import (
|
|
@@ -27,7 +26,7 @@ from flowent.storage import (
|
|
|
27
26
|
StoredState,
|
|
28
27
|
StoredToolItem,
|
|
29
28
|
)
|
|
30
|
-
from flowent.tools import
|
|
29
|
+
from flowent.tools import text_tool_result
|
|
31
30
|
from flowent.usage import (
|
|
32
31
|
TokenUsage,
|
|
33
32
|
TokenUsageInfo,
|
|
@@ -36,6 +35,7 @@ from flowent.usage import (
|
|
|
36
35
|
is_context_window_error,
|
|
37
36
|
recompute_context_usage,
|
|
38
37
|
)
|
|
38
|
+
from flowent.workflow_service import WorkflowService
|
|
39
39
|
from flowent.workspace.context import (
|
|
40
40
|
COMPACTED_CONTEXT_MARKER,
|
|
41
41
|
OPTIMIZED_CONTEXT_MARKER,
|
|
@@ -86,16 +86,33 @@ class WorkspaceRuntime:
|
|
|
86
86
|
cwd: Path,
|
|
87
87
|
mcp_manager: McpManager,
|
|
88
88
|
store: StateStore,
|
|
89
|
+
workflow_service: WorkflowService,
|
|
89
90
|
) -> None:
|
|
90
91
|
self.chat_completion = chat_completion
|
|
91
92
|
self.compact_provider = compact_provider
|
|
92
93
|
self.cwd = cwd
|
|
93
|
-
self.mcp_manager = mcp_manager
|
|
94
94
|
self.store = store
|
|
95
|
+
self.workflow_service = workflow_service
|
|
96
|
+
self.agent_runtime = FlowentAgentRuntime(
|
|
97
|
+
chat_completion=chat_completion,
|
|
98
|
+
cwd=cwd,
|
|
99
|
+
mcp_manager=mcp_manager,
|
|
100
|
+
store=store,
|
|
101
|
+
workflow_service=workflow_service,
|
|
102
|
+
)
|
|
95
103
|
self.active_response: WorkspaceResponse | None = None
|
|
96
104
|
self.generation = 0
|
|
97
105
|
self.active_compact_task: WorkspaceCompactTask | None = None
|
|
98
106
|
|
|
107
|
+
def extra_tool_specs(self) -> list[Mapping[str, object]]:
    """Delegate to ``FlowentAgentRuntime.extra_tool_specs`` on ``self.agent_runtime``."""
    agent_runtime = self.agent_runtime
    return agent_runtime.extra_tool_specs()

def model_tool_specs(self) -> list[Mapping[str, object]]:
    """Delegate to ``FlowentAgentRuntime.model_tool_specs`` on ``self.agent_runtime``."""
    agent_runtime = self.agent_runtime
    return agent_runtime.model_tool_specs()

def extra_tool_title(self, name: str) -> str | None:
    """Delegate title lookup for tool ``name`` to ``self.agent_runtime``."""
    agent_runtime = self.agent_runtime
    return agent_runtime.extra_tool_title(name)
|
|
115
|
+
|
|
99
116
|
def request_messages_for_content(
|
|
100
117
|
self,
|
|
101
118
|
state: StoredState,
|
|
@@ -229,10 +246,7 @@ class WorkspaceRuntime:
|
|
|
229
246
|
)
|
|
230
247
|
next_messages = [*state.messages, user_message]
|
|
231
248
|
self.store.save_messages(next_messages)
|
|
232
|
-
model_tool_specs =
|
|
233
|
-
*tool_specs(),
|
|
234
|
-
*list(self.mcp_manager.tool_specs()),
|
|
235
|
-
]
|
|
249
|
+
model_tool_specs = self.model_tool_specs()
|
|
236
250
|
model_history: list[ChatMessage | Mapping[str, object]] = [
|
|
237
251
|
*runtime_context_messages(self.cwd, state.settings.agent_prompt),
|
|
238
252
|
*workspace_chat_messages(
|
|
@@ -265,42 +279,11 @@ class WorkspaceRuntime:
|
|
|
265
279
|
current_output_index = 0
|
|
266
280
|
latest_usage_output_index: int | None = None
|
|
267
281
|
|
|
268
|
-
async
|
|
269
|
-
|
|
270
|
-
connection,
|
|
271
|
-
request.model_copy(
|
|
272
|
-
update={
|
|
273
|
-
"transcript": approval_transcript(next_messages),
|
|
274
|
-
"user_request": content,
|
|
275
|
-
}
|
|
276
|
-
),
|
|
277
|
-
completion=self.chat_completion,
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
async def tool_runner(
|
|
281
|
-
name: str,
|
|
282
|
-
arguments: dict[str, object],
|
|
283
|
-
context: ToolContext,
|
|
284
|
-
):
|
|
285
|
-
return await run_tool_with_path_permissions(
|
|
286
|
-
name,
|
|
287
|
-
arguments,
|
|
288
|
-
context,
|
|
289
|
-
review_approval=review_tool_approval,
|
|
290
|
-
writable_paths=[
|
|
291
|
-
Path(path.path) for path in self.store.read_writable_paths()
|
|
292
|
-
],
|
|
293
|
-
)
|
|
294
|
-
|
|
295
|
-
async for event in run_agent_stream(
|
|
296
|
-
completion=self.chat_completion,
|
|
282
|
+
async for event in self.agent_runtime.stream(
|
|
283
|
+
approval_transcript=approval_transcript(next_messages),
|
|
297
284
|
connection=connection,
|
|
298
|
-
cwd=self.cwd,
|
|
299
|
-
extra_tool_runner=self.mcp_manager.run_tool,
|
|
300
|
-
extra_tool_specs=self.mcp_manager.tool_specs(),
|
|
301
|
-
extra_tool_title=self.mcp_manager.tool_title,
|
|
302
285
|
messages=request_messages,
|
|
303
|
-
|
|
286
|
+
user_request=content,
|
|
304
287
|
):
|
|
305
288
|
if event.event == "start":
|
|
306
289
|
event_id = event.data.get("id")
|
|
@@ -617,16 +600,37 @@ class WorkspaceRuntime:
|
|
|
617
600
|
*base_request_messages,
|
|
618
601
|
*model_visible_assistant_output_messages(trimmed_message),
|
|
619
602
|
]
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
603
|
+
try:
|
|
604
|
+
response = self._start_response_from_messages(
|
|
605
|
+
content=previous_user_message.content,
|
|
606
|
+
initial_assistant_message=trimmed_message,
|
|
607
|
+
next_messages=next_messages,
|
|
608
|
+
output_start_index=assistant_retry_output_start_index(trimmed_message),
|
|
609
|
+
request_messages=request_messages,
|
|
610
|
+
state=state_before_assistant,
|
|
611
|
+
usage_request_messages=base_request_messages,
|
|
612
|
+
user_message=previous_user_message,
|
|
613
|
+
)
|
|
614
|
+
except HTTPException as error:
|
|
615
|
+
error_detail = str(error.detail or "")
|
|
616
|
+
assistant_output = AssistantOutputBuilder.from_message(trimmed_message)
|
|
617
|
+
assistant_output.append_error(
|
|
618
|
+
run_error_output_item(trimmed_message.id, error_detail).model_copy(
|
|
619
|
+
update={"id": error_id}
|
|
620
|
+
)
|
|
621
|
+
)
|
|
622
|
+
failed_message = StoredMessage(
|
|
623
|
+
author="assistant",
|
|
624
|
+
content=assistant_output.content,
|
|
625
|
+
groups=assistant_output.groups,
|
|
626
|
+
id=trimmed_message.id,
|
|
627
|
+
status="failed",
|
|
628
|
+
thinking=assistant_output.thinking,
|
|
629
|
+
tools=list(assistant_output.tools.values()),
|
|
630
|
+
usage_info=self.store.read_usage_info(),
|
|
631
|
+
)
|
|
632
|
+
self.store.save_messages([*previous_messages, failed_message])
|
|
633
|
+
raise
|
|
630
634
|
return next_messages, response
|
|
631
635
|
|
|
632
636
|
def _start_response_from_messages(
|
|
@@ -706,11 +710,14 @@ class WorkspaceRuntime:
|
|
|
706
710
|
def refresh_assistant(status: str = "running") -> StoredMessage | None:
    """Call ``update_assistant_message`` for ``status`` with ``persist=False``.

    NOTE(review): presumably refreshes the in-memory assistant message
    without saving it to the store (persist_assistant_progress uses
    ``persist=True``). Confirm against update_assistant_message.
    """
    return update_assistant_message(status, persist=False)
|
|
708
712
|
|
|
709
|
-
def persist_assistant_progress(
|
|
713
|
+
def persist_assistant_progress(
|
|
714
|
+
*, force: bool = False
|
|
715
|
+
) -> StoredMessage | None:
|
|
710
716
|
nonlocal last_progress_flush_at
|
|
711
717
|
now = time.monotonic()
|
|
712
718
|
if (
|
|
713
|
-
|
|
719
|
+
not force
|
|
720
|
+
and last_progress_flush_at > 0
|
|
714
721
|
and now - last_progress_flush_at
|
|
715
722
|
< WORKSPACE_PROGRESS_FLUSH_INTERVAL_SECONDS
|
|
716
723
|
):
|
|
@@ -719,15 +726,16 @@ class WorkspaceRuntime:
|
|
|
719
726
|
last_progress_flush_at = now
|
|
720
727
|
return update_assistant_message("running", persist=True)
|
|
721
728
|
|
|
729
|
+
def has_tool_result(tool_id: str) -> bool:
    """Report whether tool call ``tool_id`` exists and already has a non-empty result."""
    tool = assistant_output.tools.get(tool_id)
    if tool is None:
        return False
    return bool(tool.result)
|
|
732
|
+
|
|
722
733
|
try:
|
|
723
734
|
current_tool_id: str | None = None
|
|
724
735
|
turn_usage_info: TokenUsageInfo | None = None
|
|
725
736
|
current_output_index = 0
|
|
726
737
|
latest_usage_output_index: int | None = None
|
|
727
|
-
model_tool_specs =
|
|
728
|
-
*tool_specs(),
|
|
729
|
-
*list(self.mcp_manager.tool_specs()),
|
|
730
|
-
]
|
|
738
|
+
model_tool_specs = self.model_tool_specs()
|
|
731
739
|
if request_messages is None:
|
|
732
740
|
current_request_messages = self.request_messages_for_content(
|
|
733
741
|
state,
|
|
@@ -809,33 +817,6 @@ class WorkspaceRuntime:
|
|
|
809
817
|
else current_request_messages
|
|
810
818
|
)
|
|
811
819
|
|
|
812
|
-
async def review_tool_approval(request: ApprovalReviewRequest):
|
|
813
|
-
return await review_approval_request(
|
|
814
|
-
connection,
|
|
815
|
-
request.model_copy(
|
|
816
|
-
update={
|
|
817
|
-
"transcript": approval_transcript(next_messages),
|
|
818
|
-
"user_request": content,
|
|
819
|
-
}
|
|
820
|
-
),
|
|
821
|
-
completion=self.chat_completion,
|
|
822
|
-
)
|
|
823
|
-
|
|
824
|
-
async def tool_runner(
|
|
825
|
-
name: str,
|
|
826
|
-
arguments: dict[str, object],
|
|
827
|
-
context: ToolContext,
|
|
828
|
-
):
|
|
829
|
-
return await run_tool_with_path_permissions(
|
|
830
|
-
name,
|
|
831
|
-
arguments,
|
|
832
|
-
context,
|
|
833
|
-
review_approval=review_tool_approval,
|
|
834
|
-
writable_paths=[
|
|
835
|
-
Path(path.path) for path in self.store.read_writable_paths()
|
|
836
|
-
],
|
|
837
|
-
)
|
|
838
|
-
|
|
839
820
|
async def context_compactor(
|
|
840
821
|
conversation: Sequence[Mapping[str, object]],
|
|
841
822
|
) -> AgentContextUpdate | None:
|
|
@@ -882,16 +863,12 @@ class WorkspaceRuntime:
|
|
|
882
863
|
},
|
|
883
864
|
)
|
|
884
865
|
|
|
885
|
-
async for event in
|
|
886
|
-
|
|
866
|
+
async for event in self.agent_runtime.stream(
|
|
867
|
+
approval_transcript=approval_transcript(next_messages),
|
|
887
868
|
connection=connection,
|
|
888
869
|
context_compactor=context_compactor,
|
|
889
|
-
cwd=self.cwd,
|
|
890
|
-
extra_tool_runner=self.mcp_manager.run_tool,
|
|
891
|
-
extra_tool_specs=self.mcp_manager.tool_specs(),
|
|
892
|
-
extra_tool_title=self.mcp_manager.tool_title,
|
|
893
870
|
messages=current_request_messages,
|
|
894
|
-
|
|
871
|
+
user_request=content,
|
|
895
872
|
):
|
|
896
873
|
if not is_current_generation() or response.discard_on_cancel:
|
|
897
874
|
raise asyncio.CancelledError
|
|
@@ -936,6 +913,17 @@ class WorkspaceRuntime:
|
|
|
936
913
|
StoredToolItem.model_validate(tool)
|
|
937
914
|
)
|
|
938
915
|
snapshot_after_event = persist_assistant()
|
|
916
|
+
if event.event == "tool_update":
|
|
917
|
+
tool_id = event.data.get("id")
|
|
918
|
+
if (
|
|
919
|
+
isinstance(tool_id, str)
|
|
920
|
+
and tool_id in assistant_output.tools
|
|
921
|
+
):
|
|
922
|
+
had_result = has_tool_result(tool_id)
|
|
923
|
+
assistant_output.update_tool(tool_id, event.data)
|
|
924
|
+
snapshot_after_event = persist_assistant_progress(
|
|
925
|
+
force=not had_result
|
|
926
|
+
)
|
|
939
927
|
if event.event in {"tool_done", "tool_error"}:
|
|
940
928
|
tool_id = event.data.get("id")
|
|
941
929
|
if (
|