@ssweens/pi-leash 0.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +17 -0
- package/LICENSE +21 -0
- package/README.md +221 -0
- package/package.json +83 -0
- package/src/config.ts +285 -0
- package/src/hooks/index.ts +9 -0
- package/src/hooks/permission-gate.ts +925 -0
- package/src/hooks/policies.ts +315 -0
- package/src/index.ts +38 -0
- package/src/lib/executor.ts +280 -0
- package/src/lib/index.ts +16 -0
- package/src/lib/model-resolver.ts +47 -0
- package/src/lib/timing.ts +42 -0
- package/src/lib/types.ts +115 -0
- package/src/utils/events.ts +32 -0
- package/src/utils/glob-expander.ts +128 -0
- package/src/utils/matching.ts +111 -0
- package/src/utils/shell-utils.ts +139 -0
- package/src/vendor/aliou-sh/NOTICE.md +13 -0
- package/src/vendor/aliou-sh/ast.d.ts +186 -0
- package/src/vendor/aliou-sh/index.d.ts +3 -0
- package/src/vendor/aliou-sh/index.js +1397 -0
- package/src/vendor/aliou-sh/parse.d.ts +3 -0
- package/src/vendor/aliou-sh/upstream.package.json +55 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
import { stat } from "node:fs/promises";
|
|
2
|
+
import { isAbsolute, relative, resolve } from "node:path";
|
|
3
|
+
import { parse } from "../vendor/aliou-sh/index.js";
|
|
4
|
+
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
|
|
5
|
+
import type { PolicyRule, Protection, ResolvedConfig } from "../config";
|
|
6
|
+
import { emitBlocked } from "../utils/events";
|
|
7
|
+
import { expandGlob, hasGlobChars } from "../utils/glob-expander";
|
|
8
|
+
import {
|
|
9
|
+
type CompiledPattern,
|
|
10
|
+
compileFilePatterns,
|
|
11
|
+
normalizeFilePath,
|
|
12
|
+
} from "../utils/matching";
|
|
13
|
+
import { walkCommands, wordToString } from "../utils/shell-utils";
|
|
14
|
+
|
|
15
|
+
/**
 * Fallback block messages per protection level, used when a rule does not
 * provide its own `blockMessage`. The `{file}` placeholder is substituted
 * with the blocked path when the message is reported.
 */
const DEFAULT_BLOCK_MESSAGES: Record<Protection, string> = {
  // "none" never blocks, so it has no message.
  none: "",
  readOnly:
    "Writing to {file} is not allowed. This file is read-only. Use the read tool to inspect it instead of bash commands like cat or ls.",
  noAccess:
    "Accessing {file} is not allowed. This file is protected. Ask the user if changes are needed.",
};
|
|
22
|
+
|
|
23
|
+
/**
 * Tool names that are denied on a matching path, per protection level.
 * A tool call is only blocked when its name is in the set for the path's
 * effective protection.
 */
const BLOCKED_TOOLS: Record<Protection, Set<string>> = {
  none: new Set<string>(),
  readOnly: new Set<string>(["write", "edit", "bash"]),
  noAccess: new Set<string>([
    "read",
    "write",
    "edit",
    "bash",
    "grep",
    "find",
    "ls",
  ]),
};
|
|
28
|
+
|
|
29
|
+
/** A policy rule after validation, with its patterns compiled for matching. */
interface CompiledRule {
  /** Trimmed, non-empty rule identifier (shown in block notifications). */
  id: string;
  /** Protection level applied to paths matched by this rule. */
  protection: Protection;
  /** Compiled patterns; a path matching any of them is covered by the rule. */
  patterns: CompiledPattern[];
  /** Compiled exemptions; a path matching any of them is not covered. */
  allowedPatterns: CompiledPattern[];
  /** When true, the rule only applies to paths that currently exist. */
  onlyIfExists: boolean;
  /** Message reported on block; may contain the `{file}` placeholder. */
  blockMessage: string;
  /** Disabled rules are ignored during matching. */
  enabled: boolean;
}
|
|
38
|
+
|
|
39
|
+
/**
 * Report whether `filePath` (resolved against `cwd`) can be stat'ed.
 *
 * @param cwd - Base directory used to resolve relative paths.
 * @param filePath - Path to probe.
 * @returns `true` when `stat` succeeds, `false` on any failure.
 */
async function fileExists(cwd: string, filePath: string): Promise<boolean> {
  let found = false;
  try {
    const absolutePath = resolve(cwd, filePath);
    await stat(absolutePath);
    found = true;
  } catch {
    // Any failure (missing entry, permission error, ...) counts as "absent".
  }
  return found;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Map a protection level to a numeric strictness so levels can be compared.
 *
 * @param protection - Protection level to rank.
 * @returns 2 for "noAccess", 1 for "readOnly", 0 for "none".
 */
function protectionRank(protection: Protection): number {
  // Listed strictest first; a higher number wins when several rules match.
  switch (protection) {
    case "noAccess":
      return 2;
    case "readOnly":
      return 1;
    case "none":
      return 0;
  }
}
|
|
58
|
+
|
|
59
|
+
/**
 * Validate raw policy rules and compile their patterns for matching.
 *
 * Invalid rules are skipped with a console warning rather than aborting
 * setup: rules without a usable id, with an unknown protection level, or
 * without at least one non-blank pattern. Pattern entries whose `pattern`
 * is missing, not a string, or blank are dropped, so a malformed config
 * entry cannot throw here.
 *
 * @param rules - Policy rules as loaded from configuration.
 * @returns The rules that passed validation, with defaults applied
 *   (`onlyIfExists` and `enabled` default to `true`; `blockMessage`
 *   defaults to the per-protection default message).
 */
function compileRules(rules: PolicyRule[]): CompiledRule[] {
  // Keep only entries that carry a non-blank string pattern.
  const usablePatterns = <T extends { pattern: string }>(
    entries: T[] | undefined,
  ): T[] =>
    (entries ?? []).filter(
      (entry) =>
        typeof entry?.pattern === "string" && entry.pattern.trim().length > 0,
    );

  const compiled: CompiledRule[] = [];

  for (const rule of rules) {
    // A non-string id (e.g. a number from hand-edited JSON) is treated as missing.
    const id = typeof rule.id === "string" ? rule.id.trim() : "";
    if (!id) {
      console.warn("[pi-leash] Skipping policy rule without id.");
      continue;
    }

    if (
      rule.protection !== "none" &&
      rule.protection !== "readOnly" &&
      rule.protection !== "noAccess"
    ) {
      console.warn(
        `[pi-leash] Skipping policy rule "${id}": invalid protection.`,
      );
      continue;
    }

    const normalizedPatterns = usablePatterns(rule.patterns);
    if (normalizedPatterns.length === 0) {
      console.warn(
        `[pi-leash] Skipping policy rule "${id}": missing non-empty patterns.`,
      );
      continue;
    }

    const normalizedAllowedPatterns = usablePatterns(rule.allowedPatterns);

    compiled.push({
      id,
      protection: rule.protection,
      patterns: compileFilePatterns(normalizedPatterns),
      allowedPatterns: compileFilePatterns(normalizedAllowedPatterns),
      onlyIfExists: rule.onlyIfExists ?? true,
      blockMessage:
        rule.blockMessage ?? DEFAULT_BLOCK_MESSAGES[rule.protection] ?? "",
      enabled: rule.enabled ?? true,
    });
  }

  return compiled;
}
|
|
108
|
+
|
|
109
|
+
/**
 * Cheap heuristic for whether a shell token could be a file path.
 *
 * @param token - Raw token taken from a command line.
 * @returns `true` when the token starts with `~` or contains `/` or `.`.
 */
function maybePathLike(token: string): boolean {
  if (token.startsWith("~")) return true;
  // "./" and "../" prefixes always contain "/", so one character-class test
  // covers every remaining case.
  return /[/.]/.test(token);
}
|
|
118
|
+
|
|
119
|
+
/**
 * Convert a tool-supplied path into the form used for policy matching.
 *
 * Paths located inside `cwd` are expressed relative to it; the working
 * directory itself and anything outside it keep their absolute form. The
 * chosen form is then passed through `normalizeFilePath`.
 *
 * @param filePath - Path as given by the tool call (relative or absolute).
 * @param cwd - Working directory the path is resolved against.
 * @returns The normalized path to test against policy patterns.
 */
function normalizeTargetForPolicy(filePath: string, cwd: string): string {
  const absolute = resolve(cwd, filePath);
  const rel = relative(cwd, absolute);

  // Only a leading ".." *segment* means the target escapes cwd. A plain
  // startsWith("..") would also misclassify in-cwd names such as "..env".
  const escapesCwd = /^\.\.(?:[\\/]|$)/.test(rel);

  // `rel` is empty for cwd itself, and can be absolute when the target has
  // no common root with cwd (e.g. another drive on Windows).
  const candidate =
    rel !== "" && !escapesCwd && !isAbsolute(rel) ? rel : absolute;

  return normalizeFilePath(candidate);
}
|
|
128
|
+
|
|
129
|
+
/**
 * Check whether a path is covered by the patterns of any enabled rule.
 * Allow-lists and `onlyIfExists` are not considered here.
 *
 * @param filePath - Normalized path to test.
 * @param rules - Compiled policy rules.
 * @returns `true` as soon as one pattern of an enabled rule matches.
 */
function matchesAnyPolicyPattern(
  filePath: string,
  rules: CompiledRule[],
): boolean {
  for (const rule of rules) {
    if (!rule.enabled) continue;

    for (const pattern of rule.patterns) {
      if (pattern.test(filePath)) return true;
    }
  }

  return false;
}
|
|
138
|
+
|
|
139
|
+
/**
 * Expand a possibly-globbed token into concrete candidates.
 *
 * @param candidate - Token that may contain glob characters.
 * @returns The glob matches when there are any; otherwise a single-element
 *   array holding the token itself (also used for tokens without globs).
 */
async function expandCandidate(candidate: string): Promise<string[]> {
  if (hasGlobChars(candidate)) {
    const matches = await expandGlob(candidate);
    // An unmatched glob falls through so the literal token is still checked.
    if (matches.length > 0) {
      return matches;
    }
  }

  return [candidate];
}
|
|
147
|
+
|
|
148
|
+
/**
 * Collect the policy-relevant file paths referenced by a bash command.
 *
 * The command is parsed with the vendored shell parser; every argument
 * (all words after the first) and every redirect target of each command is
 * glob-expanded, normalized, and kept when it matches a pattern of an
 * enabled rule. If parsing or expansion throws, a regex tokenizer is used
 * as a best-effort fallback.
 *
 * @param command - Raw bash command line.
 * @param rules - Compiled policy rules used to filter candidates.
 * @param cwd - Working directory for path normalization.
 * @returns Unique normalized paths that match at least one policy pattern.
 */
async function extractBashFileTargets(
  command: string,
  rules: CompiledRule[],
  cwd: string,
): Promise<string[]> {
  const targets = new Set<string>();

  // Shared by the AST path and the fallback: skip empty tokens and option
  // flags, then record every expansion that a policy pattern covers.
  const maybeAddTarget = async (candidate: string): Promise<void> => {
    if (!candidate || candidate.startsWith("-")) return;

    for (const file of await expandCandidate(candidate)) {
      const normalized = normalizeTargetForPolicy(file, cwd);
      if (matchesAnyPolicyPattern(normalized, rules)) {
        targets.add(normalized);
      }
    }
  };

  try {
    const { ast } = parse(command);
    const pending: Promise<void>[] = [];

    walkCommands(ast, (cmd) => {
      const words = (cmd.words ?? []).map(wordToString);
      // Index 0 is the command name itself, not a file argument.
      for (let i = 1; i < words.length; i++) {
        pending.push(maybeAddTarget(words[i] as string));
      }

      for (const redir of cmd.redirects ?? []) {
        pending.push(maybeAddTarget(wordToString(redir.target)));
      }

      // NOTE(review): `false` is presumably the "keep walking" signal —
      // confirm against walkCommands.
      return false;
    });

    await Promise.all(pending);
  } catch {
    // Fallback tokenizer: quoted strings, or runs of characters that are not
    // whitespace, quotes, or shell operators. ">" is excluded alongside "<"
    // so that `>file` yields the token "file" rather than ">file", which
    // would never match a policy pattern.
    const tokenRegex = /"([^"]+)"|'([^']+)'|`([^`]+)`|([^\s"'`<>|;&]+)/g;

    for (const match of command.matchAll(tokenRegex)) {
      const token = match[1] ?? match[2] ?? match[3] ?? match[4] ?? "";
      if (!maybePathLike(token)) continue;

      await maybeAddTarget(token);
    }
  }

  return [...targets];
}
|
|
210
|
+
|
|
211
|
+
/**
 * Determine the strictest protection that applies to a path.
 *
 * A rule applies when it is enabled, one of its patterns matches, none of
 * its allowed patterns match, and — for `onlyIfExists` rules — the path
 * exists. Among applicable rules the highest rank wins; on ties the
 * earliest rule is kept.
 *
 * @param filePath - Normalized path to evaluate.
 * @param compiledRules - Compiled policy rules.
 * @param cwd - Working directory used to resolve the existence check.
 * @returns The winning protection, its block message and rule id, or `null`
 *   when no rule applies or the winner is "none".
 */
async function getEffectiveProtection(
  filePath: string,
  compiledRules: CompiledRule[],
  cwd: string,
): Promise<{
  protection: Protection;
  blockMessage: string;
  ruleId: string;
} | null> {
  let bestMatch: {
    protection: Protection;
    blockMessage: string;
    ruleId: string;
    rank: number;
  } | null = null;

  // The existence check depends only on (cwd, filePath), so it is resolved
  // at most once instead of once per matching rule.
  let exists: boolean | undefined;

  for (const rule of compiledRules) {
    if (!rule.enabled) continue;

    const matched = rule.patterns.some((pattern) => pattern.test(filePath));
    if (!matched) continue;

    const allowed = rule.allowedPatterns.some((pattern) =>
      pattern.test(filePath),
    );
    if (allowed) continue;

    // A rule that cannot outrank the current winner needs no filesystem probe.
    const rank = protectionRank(rule.protection);
    if (bestMatch && rank <= bestMatch.rank) continue;

    if (rule.onlyIfExists) {
      if (exists === undefined) {
        exists = await fileExists(cwd, filePath);
      }
      if (!exists) continue;
    }

    bestMatch = {
      protection: rule.protection,
      blockMessage: rule.blockMessage,
      ruleId: rule.id,
      rank,
    };
  }

  if (!bestMatch || bestMatch.protection === "none") return null;

  return {
    protection: bestMatch.protection,
    blockMessage: bestMatch.blockMessage,
    ruleId: bestMatch.ruleId,
  };
}
|
|
259
|
+
|
|
260
|
+
/**
 * Read the target path of a file-oriented tool call.
 *
 * @param input - Tool input; `file_path` takes precedence over `path`.
 * @returns A one-element array with the trimmed path, or an empty array
 *   when neither field yields a non-blank value.
 */
function extractPathTarget(input: Record<string, unknown>): string[] {
  const raw = input.file_path ?? input.path ?? "";
  const trimmed = String(raw).trim();

  if (trimmed.length === 0) {
    return [];
  }
  return [trimmed];
}
|
|
264
|
+
|
|
265
|
+
/**
 * Register the file-policy hook on `tool_call` events.
 *
 * Does nothing when the `policies` feature is disabled. Otherwise each call
 * to a path-based tool (read, write, edit, grep, find, ls) or to bash is
 * checked: the first target whose effective protection blocks the tool
 * triggers a UI warning, a "blocked" event, and a `{ block: true, reason }`
 * result. Other tools and unprotected targets pass through.
 *
 * @param pi - Extension API used to subscribe and to emit the blocked event.
 * @param config - Resolved configuration holding the feature flag and rules.
 */
export function setupPoliciesHook(pi: ExtensionAPI, config: ResolvedConfig) {
  if (!config.features.policies) return;

  const compiledRules = compileRules(config.policies.rules);

  // Tools whose target is read from the `file_path` / `path` input field.
  const pathTools = new Set<string>([
    "read",
    "write",
    "edit",
    "grep",
    "find",
    "ls",
  ]);

  pi.on("tool_call", async (event, ctx) => {
    const toolName = event.toolName;
    let targets: string[] = [];

    if (pathTools.has(toolName)) {
      targets = extractPathTarget(event.input);
    } else if (toolName === "bash") {
      const command = String(event.input.command ?? "");
      targets = await extractBashFileTargets(command, compiledRules, ctx.cwd);
    } else {
      return;
    }

    for (const target of targets) {
      // Bash targets arrive already normalized; normalizing again is assumed
      // to be idempotent — TODO confirm against normalizeFilePath.
      const normalizedTarget = normalizeTargetForPolicy(target, ctx.cwd);

      const effective = await getEffectiveProtection(
        normalizedTarget,
        compiledRules,
        ctx.cwd,
      );
      if (!effective) continue;

      const blockedTools = BLOCKED_TOOLS[effective.protection];
      if (!blockedTools.has(toolName)) continue;

      ctx.ui.notify(
        `Blocked ${toolName} on protected file: ${normalizedTarget} (${effective.ruleId})`,
        "warning",
      );

      // split/join substitutes every "{file}" occurrence and inserts the
      // path verbatim; String.replace with a string argument would replace
      // only the first one and interpret "$&", "$1", "$$" in the path.
      const reason = effective.blockMessage
        .split("{file}")
        .join(normalizedTarget);

      emitBlocked(pi, {
        feature: "policies",
        toolName,
        input: event.input,
        reason,
      });

      return { block: true, reason };
    }

    return;
  });
}
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
|
|
2
|
+
import { loadConfig } from "./config";
|
|
3
|
+
import { setupLeashHooks } from "./hooks";
|
|
4
|
+
|
|
5
|
+
/**
 * Pi Leash Extension
 *
 * Security hooks to prevent potentially dangerous operations:
 * - policies: File access policies with per-rule protection levels
 * - permission-gate: Prompts for confirmation on dangerous commands
 * - sudo-mode: Secure password handling for sudo commands
 *
 * Configuration:
 * - Global: ~/.pi/agent/settings/pi-leash.json
 *
 * Example config:
 * {
 *   "enabled": true,
 *   "features": {
 *     "policies": true,
 *     "permissionGate": true
 *   },
 *   "permissionGate": {
 *     "sudoMode": {
 *       "enabled": true,
 *       "timeout": 30000,
 *       "preserveEnv": false
 *     }
 *   }
 * }
 *
 * Extension entry point: loads the configuration and installs the hooks
 * unless the extension is disabled.
 */
export default async function (pi: ExtensionAPI) {
  const config = loadConfig();

  // A disabled configuration installs no hooks at all.
  if (config.enabled) {
    setupLeashHooks(pi, config);
  }
}
|
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Core subagent executor.
|
|
3
|
+
*
|
|
4
|
+
* Uses createAgentSession from the SDK for all subagent patterns.
|
|
5
|
+
* Supports streaming text updates, tool execution tracking, and usage tracking.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
|
9
|
+
import type { ExtensionContext } from "@mariozechner/pi-coding-agent";
|
|
10
|
+
import {
|
|
11
|
+
createAgentSession,
|
|
12
|
+
DefaultResourceLoader,
|
|
13
|
+
getAgentDir,
|
|
14
|
+
SessionManager,
|
|
15
|
+
SettingsManager,
|
|
16
|
+
} from "@mariozechner/pi-coding-agent";
|
|
17
|
+
import {
|
|
18
|
+
createExecutionTimer,
|
|
19
|
+
markExecutionEnd,
|
|
20
|
+
markExecutionStart,
|
|
21
|
+
} from "./timing";
|
|
22
|
+
import type {
|
|
23
|
+
OnTextUpdate,
|
|
24
|
+
OnToolUpdate,
|
|
25
|
+
SubagentConfig,
|
|
26
|
+
SubagentResult,
|
|
27
|
+
SubagentToolCall,
|
|
28
|
+
SubagentUsage,
|
|
29
|
+
} from "./types";
|
|
30
|
+
|
|
31
|
+
/**
 * Build a readable, practically unique identifier for one subagent run.
 *
 * The name is lower-cased, runs of non-alphanumeric characters collapse to
 * a single "-", and leading/trailing dashes are stripped. Names that leave
 * nothing behind (empty, whitespace, or punctuation only) fall back to
 * "subagent".
 *
 * @param name - Human-readable subagent name.
 * @returns `<slug>-<suffix>`, where the suffix is 8 characters of a random
 *   UUID when `crypto.randomUUID` is available, otherwise a base-36
 *   timestamp plus four random base-36 characters.
 */
function generateRunId(name: string): string {
  const slug =
    String(name ?? "")
      .trim()
      .toLowerCase()
      .replace(/[^a-z0-9]+/g, "-")
      // Without this, "!!!" would slug to "-" and bypass the fallback below.
      .replace(/^-+|-+$/g, "") || "subagent";

  const randomPart =
    typeof globalThis.crypto?.randomUUID === "function"
      ? globalThis.crypto.randomUUID().slice(0, 8)
      : // A timestamp alone collides for runs started in the same
        // millisecond, so random characters are appended.
        Date.now().toString(36) +
        Math.floor(Math.random() * 36 ** 4)
          .toString(36)
          .padStart(4, "0");

  return `${slug}-${randomPart}`;
}
|
|
43
|
+
|
|
44
|
+
/**
 * Render an arbitrary value (thrown value or tool result) as text without
 * throwing. Strings pass through; other values are JSON-encoded, falling
 * back to `String(value)` when JSON encoding throws (e.g. circular
 * structures) or yields `undefined`.
 */
function safeStringify(value: unknown): string {
  if (typeof value === "string") return value;
  try {
    const json: string | undefined = JSON.stringify(value);
    return json ?? String(value);
  } catch {
    return String(value);
  }
}

/**
 * Execute a subagent with the given configuration.
 *
 * Creates an in-memory agent session with extensions, prompt templates,
 * themes and agents files disabled, sends `userMessage`, and collects the
 * streamed text, tool calls and token/cost usage. When tools were used, the
 * returned content is only the text streamed after the last tool finished;
 * otherwise it is all streamed text. `<thinking>` blocks are stripped.
 *
 * @param config - Subagent configuration
 * @param userMessage - The user's prompt
 * @param ctx - Extension context
 * @param onTextUpdate - Callback for streaming text
 * @param signal - Abort signal
 * @param onToolUpdate - Callback for tool execution updates
 * @returns The run result. Failures of `session.prompt` are reported through
 *   `error` (or `aborted` when the signal fired) instead of being thrown.
 */
export async function executeSubagent(
  config: SubagentConfig,
  userMessage: string,
  ctx: ExtensionContext,
  onTextUpdate?: OnTextUpdate,
  signal?: AbortSignal,
  onToolUpdate?: OnToolUpdate,
): Promise<SubagentResult> {
  const runId = generateRunId(config.name);
  const executionTimer = createExecutionTimer();

  const agentDir = getAgentDir();
  const settingsManager = SettingsManager.create(ctx.cwd, agentDir);
  const resourceLoader = new DefaultResourceLoader({
    cwd: ctx.cwd,
    agentDir,
    settingsManager,
    noExtensions: true,
    noPromptTemplates: true,
    noThemes: true,
    noSkills: true,
    systemPromptOverride: () => config.systemPrompt,
    appendSystemPromptOverride: () => [],
    agentsFilesOverride: () => ({ agentsFiles: [] }),
    skillsOverride: () => ({
      skills: config.skills ?? [],
      diagnostics: [],
    }),
  });
  await resourceLoader.reload();

  const { session } = await createAgentSession({
    model: config.model,
    tools: config.tools ?? [],
    customTools: config.customTools ?? [],
    sessionManager: SessionManager.inMemory(),
    thinkingLevel: config.thinkingLevel ?? "low",
    modelRegistry: ctx.modelRegistry,
    resourceLoader,
  });

  // `accumulated` holds every streamed delta; `finalResponse` only the text
  // streamed after the most recent batch of tool calls finished.
  let accumulated = "";
  let finalResponse = "";
  let aborted = false;
  const toolCalls = new Map<string, SubagentToolCall>();

  let toolsHaveStarted = false;
  let toolsHaveCompleted = false;

  const usage: SubagentUsage = {
    inputTokens: 0,
    outputTokens: 0,
    cacheReadTokens: 0,
    cacheWriteTokens: 0,
    estimatedTokens: 0,
    llmCost: 0,
    toolCost: 0,
    totalCost: 0,
  };

  const unsubscribe = session.subscribe((event) => {
    if (event.type === "message_update") {
      if (event.assistantMessageEvent.type === "text_delta") {
        const delta = event.assistantMessageEvent.delta;
        accumulated += delta;

        if (toolsHaveCompleted) {
          finalResponse += delta;
        }

        onTextUpdate?.(delta, accumulated);
      }
    }

    if (event.type === "tool_execution_start") {
      // A new tool call invalidates any text captured as the "final" answer.
      toolsHaveStarted = true;
      toolsHaveCompleted = false;
      finalResponse = "";
      const toolCall: SubagentToolCall = {
        toolCallId: event.toolCallId,
        toolName: event.toolName,
        args: event.args ?? {},
        status: "running",
      };
      markExecutionStart(toolCall);
      toolCalls.set(event.toolCallId, toolCall);
      onToolUpdate?.([...toolCalls.values()]);
    }

    if (event.type === "tool_execution_update") {
      const existing = toolCalls.get(event.toolCallId);
      if (existing) {
        existing.args = event.args ?? existing.args;
        if (event.partialResult) {
          existing.partialResult = event.partialResult as {
            content: Array<{ type: string; text?: string }>;
            details?: unknown;
          };
        }
        onToolUpdate?.([...toolCalls.values()]);
      }
    }

    if (event.type === "tool_execution_end") {
      const existing = toolCalls.get(event.toolCallId);
      if (existing) {
        existing.status = event.isError ? "error" : "done";
        existing.result = event.result;
        markExecutionEnd(existing);
        if (event.isError && event.result) {
          // Must not throw inside the subscriber, even for odd result shapes.
          existing.error = safeStringify(event.result);
        }
        onToolUpdate?.([...toolCalls.values()]);

        const resultDetails = event.result?.details as
          | { cost?: number }
          | undefined;
        if (resultDetails?.cost !== undefined) {
          usage.toolCost = (usage.toolCost ?? 0) + resultDetails.cost;
        }
      }

      const allDone = [...toolCalls.values()].every(
        (tc) => tc.status === "done" || tc.status === "error",
      );
      if (allDone) {
        toolsHaveCompleted = true;
      }
    }

    if (event.type === "turn_end") {
      const msg = event.message;
      if (msg.role === "assistant") {
        const assistantMsg = msg as AssistantMessage;
        const msgUsage = assistantMsg.usage;
        if (msgUsage) {
          usage.inputTokens = (usage.inputTokens ?? 0) + msgUsage.input;
          usage.outputTokens = (usage.outputTokens ?? 0) + msgUsage.output;
          usage.cacheReadTokens =
            (usage.cacheReadTokens ?? 0) + msgUsage.cacheRead;
          usage.cacheWriteTokens =
            (usage.cacheWriteTokens ?? 0) + msgUsage.cacheWrite;
          usage.llmCost = (usage.llmCost ?? 0) + msgUsage.cost.total;
        }
      }
    }
  });

  // Already aborted before the prompt was sent: tear down and report it.
  if (signal?.aborted) {
    unsubscribe();
    session.dispose();
    return {
      content: "",
      aborted: true,
      toolCalls: [],
      totalDurationMs: executionTimer.getDurationMs(),
      runId,
      usage,
    };
  }

  const onAbort = (): void => {
    session.abort();
    aborted = true;
  };
  signal?.addEventListener("abort", onAbort, { once: true });

  let error: string | undefined;

  try {
    await session.prompt(userMessage);
  } catch (err) {
    if (signal?.aborted) {
      aborted = true;
    } else {
      error = err instanceof Error ? err.message : safeStringify(err);
    }
  } finally {
    // `once` only detaches the listener when it fires; detach explicitly so
    // a long-lived signal does not keep a reference to the disposed session.
    signal?.removeEventListener("abort", onAbort);
    unsubscribe();
    session.dispose();
  }

  const responseText = toolsHaveStarted ? finalResponse : accumulated;
  const cleanedContent = filterThinkingTags(responseText);

  const totalRealTokens =
    (usage.inputTokens ?? 0) +
    (usage.outputTokens ?? 0) +
    (usage.cacheReadTokens ?? 0) +
    (usage.cacheWriteTokens ?? 0);
  // With no reported usage, estimate from content length (4 chars per token).
  usage.estimatedTokens =
    totalRealTokens > 0
      ? totalRealTokens
      : Math.round(cleanedContent.length / 4);

  usage.totalCost = (usage.llmCost ?? 0) + (usage.toolCost ?? 0);

  return {
    content: cleanedContent,
    aborted,
    toolCalls: [...toolCalls.values()],
    totalDurationMs: executionTimer.getDurationMs(),
    error,
    runId,
    usage,
  };
}
|
|
274
|
+
|
|
275
|
+
/**
 * Remove every closed `<thinking>...</thinking>` block from text, together
 * with the whitespace that directly follows it. An opening tag without a
 * closing tag is left untouched.
 *
 * @param text - Text that may contain thinking blocks.
 * @returns The text without thinking blocks.
 */
export function filterThinkingTags(text: string): string {
  // Lazy body match so adjacent blocks are removed one by one.
  const thinkingBlock = /<thinking>[\s\S]*?<\/thinking>\s*/g;
  return text.replace(thinkingBlock, "");
}
|
package/src/lib/index.ts
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
export { executeSubagent, filterThinkingTags } from "./executor";
|
|
2
|
+
export { resolveModel } from "./model-resolver";
|
|
3
|
+
export {
|
|
4
|
+
createExecutionTimer,
|
|
5
|
+
markExecutionEnd,
|
|
6
|
+
markExecutionStart,
|
|
7
|
+
type TimedExecution,
|
|
8
|
+
} from "./timing";
|
|
9
|
+
export type {
|
|
10
|
+
OnTextUpdate,
|
|
11
|
+
OnToolUpdate,
|
|
12
|
+
SubagentConfig,
|
|
13
|
+
SubagentResult,
|
|
14
|
+
SubagentToolCall,
|
|
15
|
+
SubagentUsage,
|
|
16
|
+
} from "./types";
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model resolution helper for subagents.
|
|
3
|
+
*
|
|
4
|
+
* Resolves a model by provider + ID from the model registry.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
import type { Model } from "@mariozechner/pi-ai";
|
|
8
|
+
import type { ExtensionContext } from "@mariozechner/pi-coding-agent";
|
|
9
|
+
|
|
10
|
+
/**
 * Look up a model by provider and ID in the context's model registry.
 *
 * @param provider - Provider name (e.g., "openrouter", "anthropic", "openai-codex")
 * @param modelId - Model ID (e.g., "anthropic/claude-haiku-4.5")
 * @param ctx - Extension context with modelRegistry
 * @returns The first available model matching both provider and ID
 * @throws Error if the model is registered but not among the available
 *   models (reported as a missing API key), or is not registered at all
 */
export function resolveModel(
  provider: string,
  modelId: string,
  ctx: ExtensionContext,
  // biome-ignore lint/suspicious/noExplicitAny: Model type requires any for generic API
): Model<any> {
  const isRequested = (candidate: { id: string; provider: string }): boolean =>
    candidate.id === modelId && candidate.provider === provider;

  const usable = ctx.modelRegistry.getAvailable().find(isRequested);
  if (usable) {
    return usable;
  }

  // Not available: distinguish "registered but unusable" from "unknown".
  const isRegistered = ctx.modelRegistry.getAll().some(isRequested);
  const reason = isRegistered
    ? `Model "${modelId}" exists on ${provider} but no valid API key is configured.`
    : `Model "${modelId}" not found on provider "${provider}".`;

  throw new Error(reason);
}
|