@strav/brain 0.3.22 → 0.3.24
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +3 -3
- package/src/agent.ts +20 -0
- package/src/helpers.ts +144 -18
- package/src/index.ts +3 -0
- package/src/types.ts +41 -1
- package/src/workflow.ts +20 -1
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@strav/brain",
|
|
3
|
-
"version": "0.3.22",
|
|
3
|
+
"version": "0.3.24",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "AI module for the Strav framework",
|
|
6
6
|
"license": "MIT",
|
|
@@ -15,10 +15,10 @@
|
|
|
15
15
|
"CHANGELOG.md"
|
|
16
16
|
],
|
|
17
17
|
"peerDependencies": {
|
|
18
|
-
"@strav/kernel": "0.3.
|
|
18
|
+
"@strav/kernel": "0.3.24"
|
|
19
19
|
},
|
|
20
20
|
"dependencies": {
|
|
21
|
-
"@strav/workflow": "0.3.
|
|
21
|
+
"@strav/workflow": "0.3.24",
|
|
22
22
|
"zod": "^3.25 || ^4.0"
|
|
23
23
|
},
|
|
24
24
|
"scripts": {
|
package/src/agent.ts
CHANGED
|
@@ -62,6 +62,26 @@ export abstract class Agent {
|
|
|
62
62
|
/** Called when the model requests a tool call, before execution. */
|
|
63
63
|
onToolCall?(call: ToolCall): void | Promise<void>
|
|
64
64
|
|
|
65
|
+
/**
|
|
66
|
+
* Called before a tool is executed. Return `true` to suspend the agent loop
|
|
67
|
+
* before running this tool call; the runner will return a `SuspendedRun`
|
|
68
|
+
* with a JSON-serializable snapshot of the loop state. Resume later via
|
|
69
|
+
* `AgentRunner.resume(state, toolResults)` once the tool result is known.
|
|
70
|
+
*
|
|
71
|
+
* This is a policy-free primitive: the framework does not attach meaning
|
|
72
|
+
* to suspension. Integrators can use it to gate mutating tools on human
|
|
73
|
+
* approval, dispatch a tool to an external worker, rate-limit, etc.
|
|
74
|
+
*
|
|
75
|
+
* When suspension occurs mid-batch, the triggering call and any remaining
|
|
76
|
+
* unprocessed calls in the same batch are captured together in
|
|
77
|
+
* `pendingToolCalls` so the provider's tool_use/tool_result contract stays
|
|
78
|
+
* balanced on resume.
|
|
79
|
+
*/
|
|
80
|
+
shouldSuspend?(
|
|
81
|
+
call: ToolCall,
|
|
82
|
+
context: Record<string, unknown>
|
|
83
|
+
): boolean | Promise<boolean>
|
|
84
|
+
|
|
65
85
|
/** Called after a tool finishes execution. */
|
|
66
86
|
onToolResult?(call: ToolCallRecord): void | Promise<void>
|
|
67
87
|
|
package/src/helpers.ts
CHANGED
|
@@ -20,6 +20,9 @@ import type {
|
|
|
20
20
|
Usage,
|
|
21
21
|
JsonSchema,
|
|
22
22
|
SerializedThread,
|
|
23
|
+
SerializedAgentState,
|
|
24
|
+
SuspendedRun,
|
|
25
|
+
ToolCallResult,
|
|
23
26
|
} from './types.ts'
|
|
24
27
|
|
|
25
28
|
// ── Shared tool executor ─────────────────────────────────────────────────────
|
|
@@ -257,8 +260,57 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
257
260
|
return this
|
|
258
261
|
}
|
|
259
262
|
|
|
260
|
-
/** Run the agent to completion. */
|
|
261
|
-
async run(): Promise<AgentResult> {
|
|
263
|
+
/** Run the agent to completion (or until it suspends on a tool call). */
|
|
264
|
+
async run(): Promise<AgentResult | SuspendedRun> {
|
|
265
|
+
return this.runFromState(null)
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
/**
|
|
269
|
+
* Resume a previously suspended agent run with the results of the pending
|
|
270
|
+
* tool calls. Returns a completed `AgentResult` — or another `SuspendedRun`
|
|
271
|
+
* if the continuation itself hits another suspending tool call.
|
|
272
|
+
*
|
|
273
|
+
* `toolResults` must contain one entry per call in the original
|
|
274
|
+
* `SuspendedRun.pendingToolCalls`, matched by `toolCallId`. To signal a
|
|
275
|
+
* rejection, pass a string or object describing the error as the
|
|
276
|
+
* `result` — the model sees it as a normal tool failure and adapts.
|
|
277
|
+
*/
|
|
278
|
+
async resume(
|
|
279
|
+
state: SerializedAgentState,
|
|
280
|
+
toolResults: ToolCallResult[]
|
|
281
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
282
|
+
const hydratedMessages: Message[] = [...state.messages]
|
|
283
|
+
const hydratedToolCalls: ToolCallRecord[] = [...state.allToolCalls]
|
|
284
|
+
|
|
285
|
+
for (const r of toolResults) {
|
|
286
|
+
const originalCall = findToolCallInMessages(hydratedMessages, r.toolCallId)
|
|
287
|
+
|
|
288
|
+
hydratedMessages.push({
|
|
289
|
+
role: 'tool',
|
|
290
|
+
toolCallId: r.toolCallId,
|
|
291
|
+
content: typeof r.result === 'string' ? r.result : JSON.stringify(r.result),
|
|
292
|
+
})
|
|
293
|
+
|
|
294
|
+
hydratedToolCalls.push({
|
|
295
|
+
name: originalCall?.name ?? '',
|
|
296
|
+
arguments: originalCall?.arguments ?? {},
|
|
297
|
+
result: r.result,
|
|
298
|
+
duration: 0,
|
|
299
|
+
})
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
return this.runFromState({
|
|
303
|
+
messages: hydratedMessages,
|
|
304
|
+
allToolCalls: hydratedToolCalls,
|
|
305
|
+
totalUsage: { ...state.totalUsage },
|
|
306
|
+
iterations: state.iterations,
|
|
307
|
+
})
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
/** Shared loop body. Used by both `run()` (fresh state) and `resume()` (restored state). */
|
|
311
|
+
private async runFromState(
|
|
312
|
+
initial: SerializedAgentState | null
|
|
313
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
262
314
|
const agent = new this.AgentClass()
|
|
263
315
|
const config = BrainManager.config
|
|
264
316
|
|
|
@@ -274,11 +326,13 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
274
326
|
const maxTokens = agent.maxTokens ?? config.maxTokens
|
|
275
327
|
const temperature = agent.temperature ?? config.temperature
|
|
276
328
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
329
|
+
if (!initial) {
|
|
330
|
+
try {
|
|
331
|
+
await agent.onStart?.(this._input, this._context)
|
|
332
|
+
} catch (err) {
|
|
333
|
+
await agent.onError?.(err instanceof Error ? err : new Error(String(err)))
|
|
334
|
+
throw err
|
|
335
|
+
}
|
|
282
336
|
}
|
|
283
337
|
|
|
284
338
|
// Build system prompt with context interpolation
|
|
@@ -295,10 +349,14 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
295
349
|
schema = zodToJsonSchema(agent.output)
|
|
296
350
|
}
|
|
297
351
|
|
|
298
|
-
const messages: Message[] =
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
352
|
+
const messages: Message[] = initial
|
|
353
|
+
? [...initial.messages]
|
|
354
|
+
: [{ role: 'user', content: this._input }]
|
|
355
|
+
const allToolCalls: ToolCallRecord[] = initial ? [...initial.allToolCalls] : []
|
|
356
|
+
const totalUsage: Usage = initial
|
|
357
|
+
? { ...initial.totalUsage }
|
|
358
|
+
: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }
|
|
359
|
+
let iterations = initial?.iterations ?? 0
|
|
302
360
|
|
|
303
361
|
// Tool loop
|
|
304
362
|
while (iterations < maxIterations) {
|
|
@@ -374,8 +432,16 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
374
432
|
return result
|
|
375
433
|
}
|
|
376
434
|
|
|
377
|
-
// Execute tool calls
|
|
378
|
-
await this.executeTools(
|
|
435
|
+
// Execute tool calls (or suspend if the agent vetoes)
|
|
436
|
+
const suspension = await this.executeTools(
|
|
437
|
+
agent,
|
|
438
|
+
response.toolCalls,
|
|
439
|
+
messages,
|
|
440
|
+
allToolCalls,
|
|
441
|
+
totalUsage,
|
|
442
|
+
iterations
|
|
443
|
+
)
|
|
444
|
+
if (suspension) return suspension
|
|
379
445
|
}
|
|
380
446
|
|
|
381
447
|
// Max iterations reached — return what we have
|
|
@@ -519,8 +585,28 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
519
585
|
return
|
|
520
586
|
}
|
|
521
587
|
|
|
522
|
-
// Execute tools and yield events
|
|
523
|
-
for (
|
|
588
|
+
// Execute tools and yield events (or suspend if the agent vetoes)
|
|
589
|
+
for (let i = 0; i < toolCalls.length; i++) {
|
|
590
|
+
const toolCall = toolCalls[i]!
|
|
591
|
+
|
|
592
|
+
if (agent.shouldSuspend) {
|
|
593
|
+
const suspend = await agent.shouldSuspend(toolCall, this._context)
|
|
594
|
+
if (suspend) {
|
|
595
|
+
const suspended: SuspendedRun = {
|
|
596
|
+
status: 'suspended',
|
|
597
|
+
pendingToolCalls: toolCalls.slice(i),
|
|
598
|
+
state: {
|
|
599
|
+
messages: [...messages],
|
|
600
|
+
allToolCalls: [...allToolCalls],
|
|
601
|
+
totalUsage: { ...totalUsage },
|
|
602
|
+
iterations,
|
|
603
|
+
},
|
|
604
|
+
}
|
|
605
|
+
yield { type: 'suspended', suspended }
|
|
606
|
+
return
|
|
607
|
+
}
|
|
608
|
+
}
|
|
609
|
+
|
|
524
610
|
await agent.onToolCall?.(toolCall)
|
|
525
611
|
|
|
526
612
|
const start = performance.now()
|
|
@@ -565,9 +651,31 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
565
651
|
agent: Agent,
|
|
566
652
|
toolCalls: ToolCall[],
|
|
567
653
|
messages: Message[],
|
|
568
|
-
allToolCalls: ToolCallRecord[]
|
|
569
|
-
|
|
570
|
-
|
|
654
|
+
allToolCalls: ToolCallRecord[],
|
|
655
|
+
totalUsage: Usage,
|
|
656
|
+
iterations: number
|
|
657
|
+
): Promise<SuspendedRun | null> {
|
|
658
|
+
for (let i = 0; i < toolCalls.length; i++) {
|
|
659
|
+
const toolCall = toolCalls[i]!
|
|
660
|
+
|
|
661
|
+
if (agent.shouldSuspend) {
|
|
662
|
+
const suspend = await agent.shouldSuspend(toolCall, this._context)
|
|
663
|
+
if (suspend) {
|
|
664
|
+
// Capture this call + all remaining calls in the batch so the
|
|
665
|
+
// provider's tool_use/tool_result contract stays balanced on resume.
|
|
666
|
+
return {
|
|
667
|
+
status: 'suspended',
|
|
668
|
+
pendingToolCalls: toolCalls.slice(i),
|
|
669
|
+
state: {
|
|
670
|
+
messages: [...messages],
|
|
671
|
+
allToolCalls: [...allToolCalls],
|
|
672
|
+
totalUsage: { ...totalUsage },
|
|
673
|
+
iterations,
|
|
674
|
+
},
|
|
675
|
+
}
|
|
676
|
+
}
|
|
677
|
+
}
|
|
678
|
+
|
|
571
679
|
await agent.onToolCall?.(toolCall)
|
|
572
680
|
|
|
573
681
|
const start = performance.now()
|
|
@@ -585,7 +693,25 @@ export class AgentRunner<T extends Agent = Agent> {
|
|
|
585
693
|
|
|
586
694
|
messages.push(message)
|
|
587
695
|
}
|
|
696
|
+
return null
|
|
697
|
+
}
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
// ── Helpers for resume ───────────────────────────────────────────────────────
|
|
701
|
+
|
|
702
|
+
/**
|
|
703
|
+
* Walk `messages` backwards and find the `ToolCall` (on an assistant message)
|
|
704
|
+
* whose id matches `toolCallId`. Returns undefined if not found.
|
|
705
|
+
*/
|
|
706
|
+
function findToolCallInMessages(messages: Message[], toolCallId: string): ToolCall | undefined {
|
|
707
|
+
for (let i = messages.length - 1; i >= 0; i--) {
|
|
708
|
+
const m = messages[i]!
|
|
709
|
+
if (m.role === 'assistant' && m.toolCalls) {
|
|
710
|
+
const call = m.toolCalls.find(c => c.id === toolCallId)
|
|
711
|
+
if (call) return call
|
|
712
|
+
}
|
|
588
713
|
}
|
|
714
|
+
return undefined
|
|
589
715
|
}
|
|
590
716
|
|
|
591
717
|
// ── Thread ────────────────────────────────────────────────────────────────────
|
package/src/index.ts
CHANGED
|
@@ -34,6 +34,9 @@ export type {
|
|
|
34
34
|
BeforeHook,
|
|
35
35
|
AfterHook,
|
|
36
36
|
SerializedThread,
|
|
37
|
+
SerializedAgentState,
|
|
38
|
+
SuspendedRun,
|
|
39
|
+
ToolCallResult,
|
|
37
40
|
OutputSchema,
|
|
38
41
|
} from './types.ts'
|
|
39
42
|
export type { ChatOptions, GenerateOptions, GenerateResult, EmbedOptions } from './helpers.ts'
|
package/src/types.ts
CHANGED
|
@@ -112,11 +112,51 @@ export interface AgentResult<T = any> {
|
|
|
112
112
|
}
|
|
113
113
|
|
|
114
114
|
export interface AgentEvent {
|
|
115
|
-
type: 'text' | 'tool_start' | 'tool_result' | 'iteration' | 'done'
|
|
115
|
+
type: 'text' | 'tool_start' | 'tool_result' | 'iteration' | 'done' | 'suspended'
|
|
116
116
|
text?: string
|
|
117
117
|
toolCall?: ToolCallRecord
|
|
118
118
|
iteration?: number
|
|
119
119
|
result?: AgentResult
|
|
120
|
+
suspended?: SuspendedRun
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
// ── Suspend / Resume ─────────────────────────────────────────────────────────
|
|
124
|
+
|
|
125
|
+
/**
|
|
126
|
+
* A JSON-serializable snapshot of an agent loop at the moment it suspended.
|
|
127
|
+
*
|
|
128
|
+
* All fields are plain data — no functions, class instances, or cycles — so
|
|
129
|
+
* the snapshot can be stringified, stored across a process boundary, and
|
|
130
|
+
* later passed to `AgentRunner.resume()` to continue the run.
|
|
131
|
+
*/
|
|
132
|
+
export interface SerializedAgentState {
|
|
133
|
+
messages: Message[]
|
|
134
|
+
allToolCalls: ToolCallRecord[]
|
|
135
|
+
totalUsage: Usage
|
|
136
|
+
iterations: number
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
/**
|
|
140
|
+
* Result of an agent run that was suspended before executing one or more
|
|
141
|
+
* tool calls. The integrator is expected to obtain tool results out-of-band
|
|
142
|
+
* (human approval, external system, queued job, etc.) and call
|
|
143
|
+
* `AgentRunner.resume(state, toolResults)` to continue.
|
|
144
|
+
*
|
|
145
|
+
* `pendingToolCalls` contains the pending call that triggered suspension
|
|
146
|
+
* plus any subsequent tool calls from the same batch that have not been
|
|
147
|
+
* executed. Results must be supplied for each of them on resume so the
|
|
148
|
+
* conversation remains well-formed for the provider.
|
|
149
|
+
*/
|
|
150
|
+
export interface SuspendedRun {
|
|
151
|
+
status: 'suspended'
|
|
152
|
+
pendingToolCalls: ToolCall[]
|
|
153
|
+
state: SerializedAgentState
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
/** Result of a pending tool call, supplied to `AgentRunner.resume()`. */
|
|
157
|
+
export interface ToolCallResult {
|
|
158
|
+
toolCallId: string
|
|
159
|
+
result: unknown
|
|
120
160
|
}
|
|
121
161
|
|
|
122
162
|
// ── Workflow ──────────────────────────────────────────────────────────────────
|
package/src/workflow.ts
CHANGED
|
@@ -2,7 +2,7 @@ import { Workflow as BaseWorkflow } from '@strav/workflow'
|
|
|
2
2
|
import type { WorkflowContext as BaseContext } from '@strav/workflow'
|
|
3
3
|
import { AgentRunner } from './helpers.ts'
|
|
4
4
|
import type { Agent } from './agent.ts'
|
|
5
|
-
import type { AgentResult, WorkflowResult, Usage } from './types.ts'
|
|
5
|
+
import type { AgentResult, SuspendedRun, WorkflowResult, Usage } from './types.ts'
|
|
6
6
|
|
|
7
7
|
// ── AI Workflow Context ─────────────────────────────────────────────────────
|
|
8
8
|
|
|
@@ -27,6 +27,20 @@ function addUsage(total: Usage, add: Usage): void {
|
|
|
27
27
|
total.totalTokens += add.totalTokens
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
+
// Workflow orchestration runs agents to completion; suspension is a standalone
|
|
31
|
+
// primitive on AgentRunner. Surface a clear error rather than silently swallowing.
|
|
32
|
+
function assertCompleted(
|
|
33
|
+
result: AgentResult | SuspendedRun,
|
|
34
|
+
stepName: string
|
|
35
|
+
): asserts result is AgentResult {
|
|
36
|
+
if ((result as SuspendedRun).status === 'suspended') {
|
|
37
|
+
throw new Error(
|
|
38
|
+
`Workflow step "${stepName}" suspended — Workflow does not support agent suspension. ` +
|
|
39
|
+
`Use AgentRunner.run()/resume() directly, or ensure workflow agents don't define shouldSuspend.`
|
|
40
|
+
)
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
30
44
|
// ── Workflow Builder ────────────────────────────────────────────────────────
|
|
31
45
|
|
|
32
46
|
/**
|
|
@@ -60,6 +74,7 @@ export class Workflow {
|
|
|
60
74
|
this.pipeline.step(name, async (ctx: BaseContext) => {
|
|
61
75
|
const inputText = resolveInput(mapInput, ctx)
|
|
62
76
|
const result = await new AgentRunner(agent).input(inputText).run()
|
|
77
|
+
assertCompleted(result, name)
|
|
63
78
|
addUsage(this.totalUsage, result.usage)
|
|
64
79
|
return result
|
|
65
80
|
})
|
|
@@ -81,6 +96,7 @@ export class Workflow {
|
|
|
81
96
|
handler: async (ctx: BaseContext) => {
|
|
82
97
|
const inputText = resolveInput(a.mapInput, ctx)
|
|
83
98
|
const result = await new AgentRunner(a.agent).input(inputText).run()
|
|
99
|
+
assertCompleted(result, `${name}.${a.name}`)
|
|
84
100
|
addUsage(this.totalUsage, result.usage)
|
|
85
101
|
return result
|
|
86
102
|
},
|
|
@@ -104,6 +120,7 @@ export class Workflow {
|
|
|
104
120
|
this.pipeline.step(`${name}:router`, async (ctx: BaseContext) => {
|
|
105
121
|
const inputText = resolveInput(mapInput, ctx)
|
|
106
122
|
const result = await new AgentRunner(router).input(inputText).run()
|
|
123
|
+
assertCompleted(result, `${name}:router`)
|
|
107
124
|
addUsage(this.totalUsage, result.usage)
|
|
108
125
|
return result
|
|
109
126
|
})
|
|
@@ -121,6 +138,7 @@ export class Workflow {
|
|
|
121
138
|
async (ctx: BaseContext) => {
|
|
122
139
|
const inputText = resolveInput(mapInput, ctx)
|
|
123
140
|
const result = await new AgentRunner(BranchAgent).input(inputText).run()
|
|
141
|
+
assertCompleted(result, `${name}:${key}`)
|
|
124
142
|
addUsage(this.totalUsage, result.usage)
|
|
125
143
|
return result
|
|
126
144
|
},
|
|
@@ -148,6 +166,7 @@ export class Workflow {
|
|
|
148
166
|
name,
|
|
149
167
|
async (input: unknown, _ctx: BaseContext) => {
|
|
150
168
|
const result = await new AgentRunner(agent).input(String(input)).run()
|
|
169
|
+
assertCompleted(result, name)
|
|
151
170
|
addUsage(this.totalUsage, result.usage)
|
|
152
171
|
return result
|
|
153
172
|
},
|