@braintrust/pi-extension 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,598 @@
1
+ import { mkdirSync, mkdtempSync, rmSync } from "node:fs";
2
+ import { tmpdir } from "node:os";
3
+ import { join } from "node:path";
4
+ import {
5
+ createAssistantMessageEventStream,
6
+ registerApiProvider,
7
+ type Api,
8
+ type AssistantMessage,
9
+ type AssistantMessageEventStream,
10
+ type Context,
11
+ type Model,
12
+ type SimpleStreamOptions,
13
+ type ToolCall,
14
+ } from "@mariozechner/pi-ai";
15
+ import * as piCodingAgent from "@mariozechner/pi-coding-agent";
16
+ import {
17
+ createAgentSession,
18
+ DefaultResourceLoader,
19
+ SessionManager,
20
+ type ExtensionAPI,
21
+ } from "@mariozechner/pi-coding-agent";
22
+ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
23
+ import braintrustPiExtension from "./index.ts";
24
+
/**
 * Shared recorder for the mocked Braintrust client.
 *
 * Built with `vi.hoisted` so that the `vi.mock("./client.ts", ...)` factory in
 * this file can reference it.
 */
const mockState = vi.hoisted(() => {
  type RecordedCall = Record<string, unknown>;

  // One list per mocked client method that records its calls.
  const startSpans: RecordedCall[] = [];
  const logSpans: RecordedCall[] = [];
  const endSpans: RecordedCall[] = [];
  const updateSpans: RecordedCall[] = [];

  return {
    startSpans,
    logSpans,
    endSpans,
    updateSpans,
    // Counters bumped by the mocked `initialize()` and `flush()`.
    initializeCalls: 0,
    flushCalls: 0,
    // When set to true, the mocked `initialize()` throws.
    failInitialize: false,
  };
});
34
+
/**
 * Replaces `./client.ts` with an in-memory `BraintrustClient` whose methods
 * only record their inputs into `mockState`.
 */
vi.mock("./client.ts", () => {
  type SpanRecord = Record<string, unknown>;

  class MockBraintrustClient {
    /** Counts the call, then throws when `mockState.failInitialize` is set. */
    async initialize(): Promise<void> {
      mockState.initializeCalls += 1;
      if (!mockState.failInitialize) return;
      throw new Error("simulated Braintrust init failure");
    }

    /** Records the start arguments and returns a record derived from the span ids. */
    startSpan(args: SpanRecord): SpanRecord {
      mockState.startSpans.push(args);
      const id = `record-${String(args.spanId)}`;
      return { id, spanId: args.spanId, rootSpanId: args.rootSpanId };
    }

    /** Records the span together with the logged event. */
    logSpan(span: SpanRecord | undefined, event: SpanRecord): void {
      mockState.logSpans.push({ span, event });
    }

    /** Records the span together with its optional end timestamp. */
    endSpan(span: SpanRecord | undefined, endedAt?: number): void {
      mockState.endSpans.push({ span, endedAt });
    }

    /** Builds a fixed-format link from the span id, or `undefined` without a span. */
    getSpanLink(span: SpanRecord | undefined): string | undefined {
      return span
        ? `https://www.braintrust.dev/app/test-org/p/pi/logs?oid=${String(span.id)}`
        : undefined;
    }

    /** Async variant that resolves to the same value as `getSpanLink`. */
    async getSpanPermalink(span: SpanRecord | undefined): Promise<string | undefined> {
      const link = this.getSpanLink(span);
      return link;
    }

    /** Records the update arguments. */
    updateSpan(args: SpanRecord): void {
      mockState.updateSpans.push(args);
    }

    /** Counts the flush; nothing is sent anywhere. */
    async flush(): Promise<void> {
      mockState.flushCalls += 1;
    }
  }

  return { BraintrustClient: MockBraintrustClient };
});
83
+
/** Environment variables that `beforeEach` clears and `afterEach` restores. */
const ENV_KEYS = [
  "HOME",
  "TRACE_TO_BRAINTRUST",
  "BRAINTRUST_API_KEY",
  "BRAINTRUST_STATE_DIR",
  "BRAINTRUST_PROJECT",
  "BRAINTRUST_ORG_NAME",
  "BRAINTRUST_ADDITIONAL_METADATA",
] as const;

// Snapshot of those variables taken at module load; unset ones are stored as `undefined`.
const originalEnv: { [key: string]: string | undefined } = {};
for (const key of ENV_KEYS) {
  originalEnv[key] = process.env[key];
}

// Working directory at module load; `afterEach` changes back to it.
const originalProcessCwd = process.cwd();

// Directories handed out by `makeTempDir`; `afterEach` deletes them.
const tempDirs: string[] = [];

/** API identifier under which the fake streaming provider is registered. */
const TEST_API = "pi-extension-test-api" as Api;

/** Model description used for every session created by the harness. */
const TEST_MODEL: Model<Api> = {
  id: "pi-extension-test-model",
  name: "PI Extension Test Model",
  api: TEST_API,
  provider: "pi-extension-test-provider",
  baseUrl: "https://example.invalid",
  reasoning: false,
  input: ["text"],
  cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
  contextWindow: 32_000,
  maxTokens: 4_096,
};
117
+
// Reset the mock recorder and clear the tracked environment variables before every test.
beforeEach(() => {
  // Truncate in place: the mocked client holds references to these same arrays.
  const recordedLists = [
    mockState.startSpans,
    mockState.logSpans,
    mockState.endSpans,
    mockState.updateSpans,
  ];
  for (const recorded of recordedLists) {
    recorded.length = 0;
  }

  mockState.initializeCalls = 0;
  mockState.flushCalls = 0;
  mockState.failInitialize = false;

  ENV_KEYS.forEach((key) => {
    delete process.env[key];
  });
});
131
+
// Undo per-test global changes: environment, working directory, temporary directories.
afterEach(() => {
  ENV_KEYS.forEach((key) => {
    const saved = originalEnv[key];
    if (saved !== undefined) {
      process.env[key] = saved;
    } else {
      // The variable was not set when the module loaded, so remove it again.
      delete process.env[key];
    }
  });

  process.chdir(originalProcessCwd);

  // Remove temp directories newest-first until the list is empty.
  for (let dir = tempDirs.pop(); dir !== undefined; dir = tempDirs.pop()) {
    rmSync(dir, { recursive: true, force: true });
  }
});
145
+
/**
 * Creates a uniquely named directory under the OS temp directory and
 * registers it in `tempDirs` so `afterEach` removes it.
 *
 * @param prefix Leading part of the directory name.
 * @returns Path of the created directory.
 */
function makeTempDir(prefix: string): string {
  const template = join(tmpdir(), prefix);
  const created = mkdtempSync(template);
  tempDirs.push(created);
  return created;
}
151
+
/**
 * Creates an empty assistant message attributed to `model`.
 *
 * The message starts with no content, all usage and cost counters at zero,
 * a `"stop"` stop reason and the current time as its timestamp.
 *
 * @param model Model whose `api`, `provider` and `id` are copied onto the message.
 * @returns A fresh message object; callers in this file mutate it while streaming.
 */
function buildAssistantMessage(model: Model<Api>): AssistantMessage {
  const { api, provider, id } = model;
  // The same four counters appear in both the token usage and the cost breakdown.
  const zeroCounters = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };

  const message: AssistantMessage = {
    role: "assistant",
    content: [],
    api,
    provider,
    model: id,
    usage: {
      ...zeroCounters,
      totalTokens: 0,
      cost: { ...zeroCounters, total: 0 },
    },
    stopReason: "stop",
    timestamp: Date.now(),
  };
  return message;
}
177
+
/**
 * Returns the text of the most recent user message in `context`.
 *
 * String content is returned unchanged; for structured content only the
 * `"text"` parts are kept and joined with newlines.
 *
 * @returns The text, or `""` when the context contains no user message.
 */
function userText(context: Context): string {
  // Walk a reversed copy so the newest message is inspected first.
  for (const message of [...context.messages].reverse()) {
    if (message?.role !== "user") continue;

    const { content } = message;
    if (typeof content === "string") return content;

    const texts: string[] = [];
    for (const part of content) {
      if (part.type === "text") texts.push(part.text);
    }
    return texts.join("\n");
  }
  return "";
}
190
+
/** Reports whether any message in `context` has the `"toolResult"` role. */
function hasToolResults(context: Context): boolean {
  const noToolResults = context.messages.every((message) => message.role !== "toolResult");
  return !noToolResults;
}
194
+
/**
 * Appends a text block to `output` and emits the matching
 * `text_start`, `text_delta` and `text_end` events on `stream`.
 *
 * The block is empty when `text_start` is pushed and is filled with the
 * whole `text` before the single `text_delta` event.
 */
function pushText(
  stream: AssistantMessageEventStream,
  output: AssistantMessage,
  text: string,
): void {
  // `push` returns the new length, so the new block sits at length - 1.
  const contentIndex = output.content.push({ type: "text", text: "" }) - 1;
  stream.push({ type: "text_start", contentIndex, partial: output });

  const appended = output.content[contentIndex];
  if (appended?.type === "text") {
    appended.text = appended.text + text;
  }

  stream.push({ type: "text_delta", contentIndex, delta: text, partial: output });
  stream.push({ type: "text_end", contentIndex, content: text, partial: output });
}
210
+
/**
 * Appends `toolCall` to `output` and emits the matching
 * `toolcall_start`, `toolcall_delta` and `toolcall_end` events on `stream`.
 *
 * The delta event carries the tool arguments serialized as JSON.
 */
function pushToolCall(
  stream: AssistantMessageEventStream,
  output: AssistantMessage,
  toolCall: ToolCall,
): void {
  // `push` returns the new length, so the tool call sits at length - 1.
  const contentIndex = output.content.push(toolCall) - 1;

  stream.push({ type: "toolcall_start", contentIndex, partial: output });

  const serializedArguments = JSON.stringify(toolCall.arguments);
  stream.push({
    type: "toolcall_delta",
    contentIndex,
    delta: serializedArguments,
    partial: output,
  });

  stream.push({ type: "toolcall_end", contentIndex, toolCall, partial: output });
}
227
+
/**
 * Scripted stand-in for a model provider.
 *
 * The response is produced in a microtask after the stream is returned:
 * - if the context already contains tool results, a closing text message;
 * - else if the latest user text contains "parallel-tools", two `bash` tool calls;
 * - otherwise a plain text response.
 *
 * @param model Model the produced assistant message is attributed to.
 * @param context Conversation used to choose the scripted response.
 * @param _options Accepted but not used.
 * @returns The event stream that receives the scripted events.
 */
function streamTestModel(
  model: Model<Api>,
  context: Context,
  _options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const stream = createAssistantMessageEventStream();

  const respond = (): void => {
    const output = buildAssistantMessage(model);
    stream.push({ type: "start", partial: output });

    // Every branch finishes the same way: set the stop reason, emit `done`, end the stream.
    const finish = (reason: "stop" | "toolUse"): void => {
      output.stopReason = reason;
      stream.push({ type: "done", reason, message: output });
      stream.end();
    };

    if (hasToolResults(context)) {
      pushText(stream, output, "parallel tools finished");
      finish("stop");
    } else if (userText(context).includes("parallel-tools")) {
      pushToolCall(stream, output, {
        type: "toolCall",
        id: "tool-1",
        name: "bash",
        arguments: { command: "sleep 0.05; echo slow" },
      });
      pushToolCall(stream, output, {
        type: "toolCall",
        id: "tool-2",
        name: "bash",
        arguments: { command: "echo fast" },
      });
      finish("toolUse");
    } else {
      pushText(stream, output, "plain response");
      finish("stop");
    }
  };

  queueMicrotask(respond);
  return stream;
}
278
+
// Route both streaming entry points for the test API to the scripted model above.
registerApiProvider({ api: TEST_API, stream: streamTestModel, streamSimple: streamTestModel });
284
+
/**
 * Extension loaded next to the Braintrust extension in every harness session.
 *
 * It registers the test provider with `TEST_MODEL` as its only model, and a
 * `test-reload` command whose handler calls `ctx.reload()`.
 */
function testHarnessExtension(pi: ExtensionAPI): void {
  const { baseUrl } = TEST_MODEL;
  pi.registerProvider("pi-extension-test-provider", {
    baseUrl,
    apiKey: "pi-extension-test-key",
    api: TEST_API,
    models: [TEST_MODEL],
  });

  pi.registerCommand("test-reload", {
    description: "Reload the runtime for integration tests",
    handler: async (_args, ctx) => {
      await ctx.reload();
    },
  });
}
301
+
/**
 * Resolves once a zero-delay timer has fired, giving already-scheduled
 * asynchronous work a chance to run before the caller continues.
 */
async function waitForAsyncWork(): Promise<void> {
  const timerFired = new Promise<void>((resolve) => {
    setTimeout(() => resolve(), 0);
  });
  await timerFired;
}
305
+
/**
 * Uniform view over a pi agent session, so the tests run unchanged whether
 * the harness was built on the runtime API or on `createAgentSession`.
 */
interface TestSessionController {
  /** Session file of the currently active session, if any. */
  readonly sessionFile: string | undefined;
  /** Session manager of the currently active session. */
  readonly sessionManager: SessionManager;

  /** Sends a user prompt to the active session. */
  prompt(text: string): Promise<void>;
  /** Starts a new session and resolves to a boolean success flag. */
  newSession(): Promise<boolean>;
  /** Switches to the session stored at `sessionPath` and resolves to a boolean success flag. */
  switchSession(sessionPath: string): Promise<boolean>;
  /** Forks the active session at `entryId`. */
  fork(entryId: string): Promise<{ cancelled: boolean; selectedText: string }>;
  /** Disposes the underlying session or runtime. */
  dispose(): Promise<void>;
}
315
+
/**
 * Builds a pi agent session with the test harness extension and the
 * Braintrust extension loaded, rooted in its own home directory.
 *
 * Side effects: creates `<home>/workspace` and `<home>/agent`, and sets
 * `HOME` plus the Braintrust environment variables (tracing enabled, state
 * kept in `<home>/state`).
 *
 * When the installed `@mariozechner/pi-coding-agent` exports the runtime
 * factory functions they are used; otherwise the session is created through
 * `DefaultResourceLoader` and `createAgentSession`.
 *
 * @param options.rootDir Home directory to reuse; a temp directory is created when omitted.
 * @param options.sessionManager Session manager to use as-is; takes precedence over the options below.
 * @param options.sessionFile Existing session file to open.
 * @param options.sessionMode `"persistent"` creates an on-disk session; anything else is in-memory.
 * @param options.sessionsDir Sessions directory passed along when opening or creating a session.
 * @returns The created directories' paths and a `TestSessionController` for the session.
 */
async function createHarness(options?: {
  rootDir?: string;
  sessionManager?: SessionManager;
  sessionMode?: "inMemory" | "persistent";
  sessionFile?: string;
  sessionsDir?: string;
}) {
  const home = options?.rootDir ?? makeTempDir("pi-extension-home-");
  const cwd = join(home, "workspace");
  const agentDir = join(home, "agent");
  const stateDir = join(home, "state");

  for (const dir of [cwd, agentDir]) {
    mkdirSync(dir, { recursive: true });
  }

  Object.assign(process.env, {
    HOME: home,
    TRACE_TO_BRAINTRUST: "true",
    BRAINTRUST_API_KEY: "test-key",
    BRAINTRUST_PROJECT: "pi",
    BRAINTRUST_ORG_NAME: "test-org",
    BRAINTRUST_STATE_DIR: stateDir,
  });

  // Only consulted when the caller did not supply a session manager.
  const openDefaultSessionManager = () => {
    if (options?.sessionFile) {
      return SessionManager.open(options.sessionFile, options.sessionsDir);
    }
    if (options?.sessionMode === "persistent") {
      return SessionManager.create(cwd, options.sessionsDir);
    }
    return SessionManager.inMemory(cwd);
  };
  const sessionManager = options?.sessionManager ?? openDefaultSessionManager();

  // The runtime factories may be missing from the installed package, so probe for them.
  const compat = piCodingAgent as any;
  const runtimeFactoryNames = [
    "createAgentSessionRuntime",
    "createAgentSessionServices",
    "createAgentSessionFromServices",
  ];
  const supportsRuntimeApi = runtimeFactoryNames.every(
    (name) => typeof compat[name] === "function",
  );

  if (supportsRuntimeApi) {
    const runtime = await compat.createAgentSessionRuntime(
      async (init: {
        cwd: string;
        agentDir: string;
        sessionManager: SessionManager;
        sessionStartEvent?: unknown;
      }) => {
        const services = await compat.createAgentSessionServices({
          cwd: init.cwd,
          agentDir: init.agentDir,
          resourceLoaderOptions: {
            extensionFactories: [testHarnessExtension, braintrustPiExtension],
          },
        });
        const created = await compat.createAgentSessionFromServices({
          services,
          sessionManager: init.sessionManager,
          sessionStartEvent: init.sessionStartEvent,
          model: TEST_MODEL,
        });

        return { ...created, services, diagnostics: services.diagnostics };
      },
      { cwd, agentDir, sessionManager },
    );

    // `runtime.session` is read at call time so the currently active session is bound.
    const bindRuntimeSession = async (): Promise<void> => {
      await runtime.session.bindExtensions({});
    };

    // Shared tail of newSession/switchSession/fork: bind extensions unless cancelled.
    const rebindUnlessCancelled = async (result: any) => {
      if (!result.cancelled) {
        await bindRuntimeSession();
      }
      return result;
    };

    await bindRuntimeSession();

    const session: TestSessionController = {
      prompt: (text) => runtime.session.prompt(text),
      newSession: async () => {
        const outcome = await rebindUnlessCancelled(await runtime.newSession());
        return !outcome.cancelled;
      },
      switchSession: async (sessionPath) => {
        const outcome = await rebindUnlessCancelled(await runtime.switchSession(sessionPath));
        return !outcome.cancelled;
      },
      fork: async (entryId) => {
        const outcome = await rebindUnlessCancelled(await runtime.fork(entryId));
        return {
          cancelled: outcome.cancelled,
          selectedText: outcome.selectedText ?? "",
        };
      },
      dispose: async () => {
        await runtime.dispose();
      },
      get sessionFile() {
        return runtime.session.sessionFile;
      },
      get sessionManager() {
        return runtime.session.sessionManager as SessionManager;
      },
    };

    return { agentDir, cwd, session, stateDir };
  }

  // Fallback path: build the session directly from a resource loader.
  const resourceLoader = new DefaultResourceLoader({
    cwd,
    agentDir,
    extensionFactories: [testHarnessExtension, braintrustPiExtension],
  });
  await resourceLoader.reload();

  const created = await createAgentSession({
    cwd,
    agentDir,
    model: TEST_MODEL,
    resourceLoader,
    sessionManager,
  });
  const legacySession = created.session;

  const session: TestSessionController = {
    prompt: (text) => legacySession.prompt(text),
    newSession: () => legacySession.newSession(),
    switchSession: (sessionPath) => legacySession.switchSession(sessionPath),
    fork: async (entryId) => {
      const outcome = await legacySession.fork(entryId);
      return {
        cancelled: outcome.cancelled,
        selectedText: outcome.selectedText,
      };
    },
    dispose: async () => {
      legacySession.dispose();
    },
    get sessionFile() {
      return legacySession.sessionFile;
    },
    get sessionManager() {
      return legacySession.sessionManager;
    },
  };

  return { agentDir, cwd, session, stateDir };
}
475
+
/** Returns the recorded `startSpan` arguments that describe a task span without a parent. */
function rootTaskSpans(): Array<Record<string, unknown>> {
  const roots: Array<Record<string, unknown>> = [];
  for (const span of mockState.startSpans) {
    if (span.type !== "task") continue;
    if (span.parentSpanId !== undefined) continue;
    roots.push(span);
  }
  return roots;
}
481
+
describe("braintrustPiExtension integration", () => {
  type RecordedSpan = Record<string, unknown>;

  // Reads the `metadata` field of recorded span arguments as a plain record.
  const metadataOf = (span: RecordedSpan): RecordedSpan | undefined =>
    span.metadata as RecordedSpan | undefined;

  // Recorded span starts of the given type, in the order they were started.
  const startedSpansOfType = (type: string): RecordedSpan[] =>
    mockState.startSpans.filter((span) => span.type === type);

  it("restores persisted trace state when reopening the same pi session", async () => {
    const rootDir = makeTempDir("pi-extension-home-");
    const sessionsDir = makeTempDir("pi-extension-sessions-");

    const initial = await createHarness({ rootDir, sessionMode: "persistent", sessionsDir });
    await initial.session.prompt("create a traced turn");
    const firstSessionFile = initial.session.sessionFile;

    expect(firstSessionFile).toBeTruthy();
    expect(rootTaskSpans()).toHaveLength(1);

    await initial.session.dispose();
    await waitForAsyncWork();

    // Same home directory and session file, but a brand-new harness.
    const reopened = await createHarness({
      rootDir,
      sessionFile: firstSessionFile!,
      sessionsDir,
    });
    await reopened.session.prompt("resume the same traced session");
    await reopened.session.dispose();
    await waitForAsyncWork();

    // Still one root span; each prompt added one child task span.
    expect(rootTaskSpans()).toHaveLength(1);
    const childTaskSpans = startedSpansOfType("task").filter(
      (span) => span.parentSpanId !== undefined,
    );
    expect(childTaskSpans).toHaveLength(2);
  });

  it("keeps one root span across session switch, fork, and resume flows", async () => {
    const sessionsDir = makeTempDir("pi-extension-sessions-");
    const harness = await createHarness({ sessionMode: "persistent", sessionsDir });
    const session = harness.session;

    await session.prompt("session a");
    const sessionAFile = session.sessionFile;

    const switched = await session.newSession();
    expect(switched).toBe(true);
    await session.prompt("session b");
    const sessionBFile = session.sessionFile;

    // Fork from the first message entry of session b.
    const firstMessageEntry = session.sessionManager
      .getBranch()
      .find((entry) => entry.type === "message");
    const forkEntryId = firstMessageEntry?.id;
    expect(forkEntryId).toBeTruthy();
    const forked = await session.fork(forkEntryId!);
    expect(forked.cancelled).toBe(false);
    await session.prompt("session c from fork");
    const sessionCFile = session.sessionFile;

    expect(sessionAFile).toBeTruthy();
    expect(sessionBFile).toBeTruthy();
    expect(sessionCFile).toBeTruthy();
    expect(sessionCFile).not.toBe(sessionBFile);

    const resumed = await session.switchSession(sessionAFile!);
    expect(resumed).toBe(true);
    await session.prompt("back on session a");

    await session.dispose();
    await waitForAsyncWork();

    // Three root spans for four prompts: returning to session a did not add one.
    const roots = rootTaskSpans();
    expect(roots).toHaveLength(3);

    const forkRoot = roots.find((span) => metadataOf(span)?.opened_via === "session_fork");
    expect(forkRoot).toMatchObject({
      metadata: {
        opened_via: "session_fork",
        parent_session_file: sessionBFile,
      },
    });
  });

  it("preserves pi's parallel tool end ordering when creating tool spans", async () => {
    const { session } = await createHarness();

    await session.prompt("parallel-tools");
    await session.dispose();
    await waitForAsyncWork();

    const llmSpans = startedSpansOfType("llm");
    const toolSpans = startedSpansOfType("tool");
    const firstLlmSpanId = llmSpans[0]?.spanId;

    expect(toolSpans).toHaveLength(2);

    const toolCallIds = toolSpans.map((span) => metadataOf(span)?.tool_call_id);
    expect(toolCallIds).toEqual(["tool-1", "tool-2"]);

    // Both tool spans hang off the LLM span that was started first.
    const toolParentIds = toolSpans.map((span) => span.parentSpanId);
    expect(toolParentIds).toEqual([firstLlmSpanId, firstLlmSpanId]);
  });

  it("stops tracing new work after Braintrust initialization fails", async () => {
    mockState.failInitialize = true;

    const { session } = await createHarness();
    await waitForAsyncWork();
    await session.prompt("plain-response");
    await session.dispose();
    await waitForAsyncWork();

    expect(mockState.initializeCalls).toBe(1);
    expect(mockState.startSpans).toEqual([]);
    expect(mockState.flushCalls).toBe(0);
  });
});