@oh-my-pi/pi-ai 8.1.0 → 8.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -12
- package/package.json +38 -14
- package/src/cli.ts +6 -6
- package/src/providers/amazon-bedrock.ts +12 -13
- package/src/providers/anthropic.ts +25 -26
- package/src/providers/cursor.ts +57 -57
- package/src/providers/google-gemini-cli-usage.ts +2 -2
- package/src/providers/google-gemini-cli.ts +8 -10
- package/src/providers/google-shared.ts +12 -13
- package/src/providers/google-vertex.ts +7 -7
- package/src/providers/google.ts +8 -8
- package/src/providers/openai-codex/request-transformer.ts +6 -6
- package/src/providers/openai-codex-responses.ts +28 -28
- package/src/providers/openai-completions.ts +39 -39
- package/src/providers/openai-responses.ts +31 -31
- package/src/providers/transform-messages.ts +3 -3
- package/src/storage.ts +29 -19
- package/src/stream.ts +6 -6
- package/src/types.ts +1 -2
- package/src/usage/claude.ts +4 -4
- package/src/usage/github-copilot.ts +3 -4
- package/src/usage/google-antigravity.ts +3 -3
- package/src/usage/openai-codex.ts +4 -4
- package/src/usage/zai.ts +3 -3
- package/src/usage.ts +0 -1
- package/src/utils/event-stream.ts +4 -4
- package/src/utils/oauth/anthropic.ts +0 -1
- package/src/utils/oauth/callback-server.ts +2 -3
- package/src/utils/oauth/github-copilot.ts +2 -3
- package/src/utils/oauth/google-antigravity.ts +0 -1
- package/src/utils/oauth/google-gemini-cli.ts +2 -3
- package/src/utils/oauth/index.ts +11 -12
- package/src/utils/oauth/openai-codex.ts +0 -1
- package/src/utils/overflow.ts +2 -2
- package/src/utils/validation.ts +4 -5
package/src/providers/cursor.ts
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import { createHash } from "node:crypto";
|
|
2
|
-
import
|
|
2
|
+
import * as fs from "node:fs/promises";
|
|
3
3
|
import http2 from "node:http2";
|
|
4
4
|
import { create, fromBinary, fromJson, type JsonValue, toBinary, toJson } from "@bufbuild/protobuf";
|
|
5
5
|
import { ValueSchema } from "@bufbuild/protobuf/wkt";
|
|
6
|
-
import
|
|
6
|
+
import JSON5 from "json5";
|
|
7
|
+
import { calculateCost } from "../models";
|
|
7
8
|
import type {
|
|
8
9
|
Api,
|
|
9
10
|
AssistantMessage,
|
|
@@ -22,11 +23,10 @@ import type {
|
|
|
22
23
|
Tool,
|
|
23
24
|
ToolCall,
|
|
24
25
|
ToolResultMessage,
|
|
25
|
-
} from "
|
|
26
|
-
import { AssistantMessageEventStream } from "
|
|
27
|
-
import { parseStreamingJson } from "
|
|
28
|
-
import { formatErrorMessageWithRetryAfter } from "
|
|
29
|
-
import JSON5 from "json5";
|
|
26
|
+
} from "../types";
|
|
27
|
+
import { AssistantMessageEventStream } from "../utils/event-stream";
|
|
28
|
+
import { parseStreamingJson } from "../utils/json-parse";
|
|
29
|
+
import { formatErrorMessageWithRetryAfter } from "../utils/retry-after";
|
|
30
30
|
import type { McpToolDefinition } from "./cursor/gen/agent_pb";
|
|
31
31
|
import {
|
|
32
32
|
AgentClientMessageSchema,
|
|
@@ -142,7 +142,7 @@ async function appendCursorDebugLog(entry: CursorLogEntry): Promise<void> {
|
|
|
142
142
|
const logPath = process.env.DEBUG_CURSOR_LOG;
|
|
143
143
|
if (!logPath) return;
|
|
144
144
|
try {
|
|
145
|
-
await appendFile(logPath, `${JSON.stringify(entry, debugReplacer)}\n`);
|
|
145
|
+
await fs.appendFile(logPath, `${JSON.stringify(entry, debugReplacer)}\n`);
|
|
146
146
|
} catch {
|
|
147
147
|
// Ignore debug log failures
|
|
148
148
|
}
|
|
@@ -245,7 +245,7 @@ function decodeLogData(value: unknown): unknown {
|
|
|
245
245
|
return value;
|
|
246
246
|
}
|
|
247
247
|
if (Array.isArray(value)) {
|
|
248
|
-
return value.map(
|
|
248
|
+
return value.map(entry => decodeLogData(entry));
|
|
249
249
|
}
|
|
250
250
|
const record = value as Record<string, unknown>;
|
|
251
251
|
const typeName = record.$typeName;
|
|
@@ -375,13 +375,13 @@ export const streamCursor: StreamFunction<"cursor-agent"> = (
|
|
|
375
375
|
get firstTokenTime() {
|
|
376
376
|
return firstTokenTime;
|
|
377
377
|
},
|
|
378
|
-
setTextBlock:
|
|
378
|
+
setTextBlock: b => {
|
|
379
379
|
currentTextBlock = b;
|
|
380
380
|
},
|
|
381
|
-
setThinkingBlock:
|
|
381
|
+
setThinkingBlock: b => {
|
|
382
382
|
currentThinkingBlock = b;
|
|
383
383
|
},
|
|
384
|
-
setToolCall:
|
|
384
|
+
setToolCall: t => {
|
|
385
385
|
currentToolCall = t;
|
|
386
386
|
},
|
|
387
387
|
setFirstTokenTime: () => {
|
|
@@ -427,7 +427,7 @@ export const streamCursor: StreamFunction<"cursor-agent"> = (
|
|
|
427
427
|
usageState,
|
|
428
428
|
requestContextTools,
|
|
429
429
|
onConversationCheckpoint,
|
|
430
|
-
).catch(
|
|
430
|
+
).catch(error => {
|
|
431
431
|
log("error", "handleServerMessage", { error: String(error) });
|
|
432
432
|
});
|
|
433
433
|
} catch (e) {
|
|
@@ -452,7 +452,7 @@ export const streamCursor: StreamFunction<"cursor-agent"> = (
|
|
|
452
452
|
heartbeatTimer = setInterval(sendHeartbeat, 5000);
|
|
453
453
|
|
|
454
454
|
await new Promise<void>((resolve, reject) => {
|
|
455
|
-
h2Request!.on("trailers",
|
|
455
|
+
h2Request!.on("trailers", trailers => {
|
|
456
456
|
const status = trailers["grpc-status"];
|
|
457
457
|
const msg = trailers["grpc-message"];
|
|
458
458
|
if (status && status !== "0") {
|
|
@@ -662,9 +662,9 @@ async function handleShellStreamArgs(
|
|
|
662
662
|
args as any,
|
|
663
663
|
execHandlers?.shell,
|
|
664
664
|
onToolResult,
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
665
|
+
toolResult => buildShellResultFromToolResult(args as any, toolResult),
|
|
666
|
+
reason => buildShellRejectedResult((args as any).command, (args as any).workingDirectory, reason),
|
|
667
|
+
error => buildShellFailureResult((args as any).command, (args as any).workingDirectory, error),
|
|
668
668
|
);
|
|
669
669
|
|
|
670
670
|
sendShellStreamEvent(h2Request, execMsg, { case: "start", value: create(ShellStreamStartSchema, {}) });
|
|
@@ -810,9 +810,9 @@ async function handleExecServerMessage(
|
|
|
810
810
|
args,
|
|
811
811
|
execHandlers?.read,
|
|
812
812
|
onToolResult,
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
813
|
+
toolResult => buildReadResultFromToolResult(args.path, toolResult),
|
|
814
|
+
reason => buildReadRejectedResult(args.path, reason),
|
|
815
|
+
error => buildReadErrorResult(args.path, error),
|
|
816
816
|
);
|
|
817
817
|
sendExecClientMessage(h2Request, execMsg, "readResult", execResult);
|
|
818
818
|
return;
|
|
@@ -823,9 +823,9 @@ async function handleExecServerMessage(
|
|
|
823
823
|
args,
|
|
824
824
|
execHandlers?.ls,
|
|
825
825
|
onToolResult,
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
826
|
+
toolResult => buildLsResultFromToolResult(args.path, toolResult),
|
|
827
|
+
reason => buildLsRejectedResult(args.path, reason),
|
|
828
|
+
error => buildLsErrorResult(args.path, error),
|
|
829
829
|
);
|
|
830
830
|
sendExecClientMessage(h2Request, execMsg, "lsResult", execResult);
|
|
831
831
|
return;
|
|
@@ -836,9 +836,9 @@ async function handleExecServerMessage(
|
|
|
836
836
|
args,
|
|
837
837
|
execHandlers?.grep,
|
|
838
838
|
onToolResult,
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
839
|
+
toolResult => buildGrepResultFromToolResult(args, toolResult),
|
|
840
|
+
reason => buildGrepErrorResult(reason),
|
|
841
|
+
error => buildGrepErrorResult(error),
|
|
842
842
|
);
|
|
843
843
|
sendExecClientMessage(h2Request, execMsg, "grepResult", execResult);
|
|
844
844
|
return;
|
|
@@ -849,7 +849,7 @@ async function handleExecServerMessage(
|
|
|
849
849
|
args,
|
|
850
850
|
execHandlers?.write,
|
|
851
851
|
onToolResult,
|
|
852
|
-
|
|
852
|
+
toolResult =>
|
|
853
853
|
buildWriteResultFromToolResult(
|
|
854
854
|
{
|
|
855
855
|
path: args.path,
|
|
@@ -859,8 +859,8 @@ async function handleExecServerMessage(
|
|
|
859
859
|
},
|
|
860
860
|
toolResult,
|
|
861
861
|
),
|
|
862
|
-
|
|
863
|
-
|
|
862
|
+
reason => buildWriteRejectedResult(args.path, reason),
|
|
863
|
+
error => buildWriteErrorResult(args.path, error),
|
|
864
864
|
);
|
|
865
865
|
sendExecClientMessage(h2Request, execMsg, "writeResult", execResult);
|
|
866
866
|
return;
|
|
@@ -871,9 +871,9 @@ async function handleExecServerMessage(
|
|
|
871
871
|
args,
|
|
872
872
|
execHandlers?.delete,
|
|
873
873
|
onToolResult,
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
874
|
+
toolResult => buildDeleteResultFromToolResult(args.path, toolResult),
|
|
875
|
+
reason => buildDeleteRejectedResult(args.path, reason),
|
|
876
|
+
error => buildDeleteErrorResult(args.path, error),
|
|
877
877
|
);
|
|
878
878
|
sendExecClientMessage(h2Request, execMsg, "deleteResult", execResult);
|
|
879
879
|
return;
|
|
@@ -884,9 +884,9 @@ async function handleExecServerMessage(
|
|
|
884
884
|
args,
|
|
885
885
|
execHandlers?.shell,
|
|
886
886
|
onToolResult,
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
887
|
+
toolResult => buildShellResultFromToolResult(args, toolResult),
|
|
888
|
+
reason => buildShellRejectedResult(args.command, args.workingDirectory, reason),
|
|
889
|
+
error => buildShellFailureResult(args.command, args.workingDirectory, error),
|
|
890
890
|
);
|
|
891
891
|
sendExecClientMessage(h2Request, execMsg, "shellResult", execResult);
|
|
892
892
|
return;
|
|
@@ -944,9 +944,9 @@ async function handleExecServerMessage(
|
|
|
944
944
|
args,
|
|
945
945
|
execHandlers?.diagnostics,
|
|
946
946
|
onToolResult,
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
947
|
+
toolResult => buildDiagnosticsResultFromToolResult(args.path, toolResult),
|
|
948
|
+
reason => buildDiagnosticsRejectedResult(args.path, reason),
|
|
949
|
+
error => buildDiagnosticsErrorResult(args.path, error),
|
|
950
950
|
);
|
|
951
951
|
sendExecClientMessage(h2Request, execMsg, "diagnosticsResult", execResult);
|
|
952
952
|
return;
|
|
@@ -958,9 +958,9 @@ async function handleExecServerMessage(
|
|
|
958
958
|
mcpCall,
|
|
959
959
|
execHandlers?.mcp,
|
|
960
960
|
onToolResult,
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
961
|
+
toolResult => buildMcpResultFromToolResult(mcpCall, toolResult),
|
|
962
|
+
_reason => buildMcpToolNotFoundResult(mcpCall),
|
|
963
|
+
error => buildMcpErrorResult(error),
|
|
964
964
|
);
|
|
965
965
|
sendExecClientMessage(h2Request, execMsg, "mcpResult", execResult);
|
|
966
966
|
return;
|
|
@@ -1075,7 +1075,7 @@ async function applyToolResultHandler(
|
|
|
1075
1075
|
}
|
|
1076
1076
|
|
|
1077
1077
|
function toolResultToText(toolResult: ToolResultMessage): string {
|
|
1078
|
-
return toolResult.content.map(
|
|
1078
|
+
return toolResult.content.map(item => (item.type === "text" ? item.text : `[${item.mimeType} image]`)).join("\n");
|
|
1079
1079
|
}
|
|
1080
1080
|
|
|
1081
1081
|
function toolResultWasTruncated(toolResult: ToolResultMessage): boolean {
|
|
@@ -1274,8 +1274,8 @@ function buildLsResultFromToolResult(path: string, toolResult: ToolResultMessage
|
|
|
1274
1274
|
const rootPath = path || ".";
|
|
1275
1275
|
const entries = text
|
|
1276
1276
|
.split("\n")
|
|
1277
|
-
.map(
|
|
1278
|
-
.filter(
|
|
1277
|
+
.map(line => line.trim())
|
|
1278
|
+
.filter(line => line.length > 0 && !line.startsWith("["));
|
|
1279
1279
|
const childrenDirs: LsDirectoryTreeNode[] = [];
|
|
1280
1280
|
const childrenFiles: LsDirectoryTreeNode_File[] = [];
|
|
1281
1281
|
|
|
@@ -1346,8 +1346,8 @@ function buildGrepResultFromToolResult(
|
|
|
1346
1346
|
const clientTruncated = toolResultDetailBoolean(toolResult, "truncated");
|
|
1347
1347
|
const lines = text
|
|
1348
1348
|
.split("\n")
|
|
1349
|
-
.map(
|
|
1350
|
-
.filter(
|
|
1349
|
+
.map(line => line.trimEnd())
|
|
1350
|
+
.filter(line => line.length > 0 && !line.startsWith("[") && !line.toLowerCase().startsWith("no matches"));
|
|
1351
1351
|
|
|
1352
1352
|
const workspaceKey = args.path || ".";
|
|
1353
1353
|
let unionResult: GrepUnionResult;
|
|
@@ -1367,7 +1367,7 @@ function buildGrepResultFromToolResult(
|
|
|
1367
1367
|
});
|
|
1368
1368
|
} else if (outputMode === "count") {
|
|
1369
1369
|
const counts = lines
|
|
1370
|
-
.map(
|
|
1370
|
+
.map(line => {
|
|
1371
1371
|
const separatorIndex = line.lastIndexOf(":");
|
|
1372
1372
|
if (separatorIndex === -1) {
|
|
1373
1373
|
return null;
|
|
@@ -1417,7 +1417,7 @@ function buildGrepResultFromToolResult(
|
|
|
1417
1417
|
const matches = Array.from(matchMap.entries()).map(([file, matches]) =>
|
|
1418
1418
|
create(GrepFileMatchSchema, {
|
|
1419
1419
|
file,
|
|
1420
|
-
matches: matches.map(
|
|
1420
|
+
matches: matches.map(entry =>
|
|
1421
1421
|
create(GrepContentMatchSchema, {
|
|
1422
1422
|
lineNumber: entry.line,
|
|
1423
1423
|
content: entry.content,
|
|
@@ -1586,7 +1586,7 @@ function buildTodoWriteArgs(toolCall: CursorUpdateTodosToolCall): {
|
|
|
1586
1586
|
const todos = toolCall.updateTodosToolCall?.args?.todos;
|
|
1587
1587
|
if (!todos) return null;
|
|
1588
1588
|
return {
|
|
1589
|
-
todos: todos.map(
|
|
1589
|
+
todos: todos.map(todo => ({
|
|
1590
1590
|
id: typeof todo.id === "string" && todo.id.length > 0 ? todo.id : undefined,
|
|
1591
1591
|
content: typeof todo.content === "string" ? todo.content : "",
|
|
1592
1592
|
activeForm: typeof todo.content === "string" ? todo.content : "",
|
|
@@ -1599,7 +1599,7 @@ function buildMcpResultFromToolResult(_mcpCall: CursorMcpCall, toolResult: ToolR
|
|
|
1599
1599
|
if (toolResult.isError) {
|
|
1600
1600
|
return buildMcpErrorResult(toolResultToText(toolResult) || "MCP tool failed");
|
|
1601
1601
|
}
|
|
1602
|
-
const content = toolResult.content.map(
|
|
1602
|
+
const content = toolResult.content.map(item => {
|
|
1603
1603
|
if (item.type === "image") {
|
|
1604
1604
|
return create(McpToolResultContentItemSchema, {
|
|
1605
1605
|
content: {
|
|
@@ -1810,12 +1810,12 @@ function buildMcpToolDefinitions(tools: Tool[] | undefined): McpToolDefinition[]
|
|
|
1810
1810
|
return [];
|
|
1811
1811
|
}
|
|
1812
1812
|
|
|
1813
|
-
const advertisedTools = tools.filter(
|
|
1813
|
+
const advertisedTools = tools.filter(tool => !CURSOR_NATIVE_TOOL_NAMES.has(tool.name));
|
|
1814
1814
|
if (advertisedTools.length === 0) {
|
|
1815
1815
|
return [];
|
|
1816
1816
|
}
|
|
1817
1817
|
|
|
1818
|
-
return advertisedTools.map(
|
|
1818
|
+
return advertisedTools.map(tool => {
|
|
1819
1819
|
const jsonSchema = tool.parameters as Record<string, unknown> | undefined;
|
|
1820
1820
|
const schemaValue: JsonValue =
|
|
1821
1821
|
jsonSchema && typeof jsonSchema === "object"
|
|
@@ -1841,7 +1841,7 @@ function extractUserMessageText(msg: Message): string {
|
|
|
1841
1841
|
if (typeof content === "string") return content.trim();
|
|
1842
1842
|
const text = content
|
|
1843
1843
|
.filter((c): c is TextContent => c.type === "text")
|
|
1844
|
-
.map(
|
|
1844
|
+
.map(c => c.text)
|
|
1845
1845
|
.join("\n");
|
|
1846
1846
|
return text.trim();
|
|
1847
1847
|
}
|
|
@@ -1854,7 +1854,7 @@ function extractAssistantMessageText(msg: Message): string {
|
|
|
1854
1854
|
if (!Array.isArray(msg.content)) return "";
|
|
1855
1855
|
return msg.content
|
|
1856
1856
|
.filter((c): c is TextContent => c.type === "text")
|
|
1857
|
-
.map(
|
|
1857
|
+
.map(c => c.text)
|
|
1858
1858
|
.join("\n");
|
|
1859
1859
|
}
|
|
1860
1860
|
|
|
@@ -2007,7 +2007,7 @@ function buildGrpcRequest(
|
|
|
2007
2007
|
// Build conversation turns from prior messages (excluding the last user message)
|
|
2008
2008
|
const turns = buildConversationTurns(context.messages);
|
|
2009
2009
|
|
|
2010
|
-
const hasMatchingPrompt = state.conversationState?.rootPromptMessagesJson?.some(
|
|
2010
|
+
const hasMatchingPrompt = state.conversationState?.rootPromptMessagesJson?.some(entry =>
|
|
2011
2011
|
Buffer.from(entry).equals(systemPromptId),
|
|
2012
2012
|
);
|
|
2013
2013
|
|
|
@@ -2064,7 +2064,7 @@ function buildGrpcRequest(
|
|
|
2064
2064
|
|
|
2065
2065
|
const requestBytes = toBinary(AgentClientMessageSchema, clientMessage);
|
|
2066
2066
|
|
|
2067
|
-
const toolNames = context.tools?.map(
|
|
2067
|
+
const toolNames = context.tools?.map(tool => tool.name) ?? [];
|
|
2068
2068
|
const detail =
|
|
2069
2069
|
process.env.DEBUG_CURSOR === "2"
|
|
2070
2070
|
? ` ${JSON.stringify(clientMessage.message.value, debugReplacer, 2)?.slice(0, 2000)}`
|
|
@@ -2082,6 +2082,6 @@ function buildGrpcRequest(
|
|
|
2082
2082
|
function extractText(content: (TextContent | ImageContent)[]): string {
|
|
2083
2083
|
return content
|
|
2084
2084
|
.filter((c): c is TextContent => c.type === "text")
|
|
2085
|
-
.map(
|
|
2085
|
+
.map(c => c.text)
|
|
2086
2086
|
.join("\n");
|
|
2087
2087
|
}
|
|
@@ -6,8 +6,8 @@ import type {
|
|
|
6
6
|
UsageProvider,
|
|
7
7
|
UsageReport,
|
|
8
8
|
UsageWindow,
|
|
9
|
-
} from "
|
|
10
|
-
import { refreshGoogleCloudToken } from "
|
|
9
|
+
} from "../usage";
|
|
10
|
+
import { refreshGoogleCloudToken } from "../utils/oauth/google-gemini-cli";
|
|
11
11
|
|
|
12
12
|
const DEFAULT_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
|
13
13
|
const CACHE_TTL_MS = 60_000;
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
* Shared implementation for both google-gemini-cli and google-antigravity providers.
|
|
4
4
|
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
|
|
5
5
|
*/
|
|
6
|
-
|
|
7
6
|
import { createHash } from "node:crypto";
|
|
8
7
|
import type { Content, ThinkingConfig } from "@google/genai";
|
|
9
|
-
import {
|
|
8
|
+
import { abortableSleep } from "@oh-my-pi/pi-utils";
|
|
9
|
+
import { calculateCost } from "../models";
|
|
10
10
|
import type {
|
|
11
11
|
Api,
|
|
12
12
|
AssistantMessage,
|
|
@@ -17,10 +17,9 @@ import type {
|
|
|
17
17
|
TextContent,
|
|
18
18
|
ThinkingContent,
|
|
19
19
|
ToolCall,
|
|
20
|
-
} from "
|
|
21
|
-
import { AssistantMessageEventStream } from "
|
|
22
|
-
import { sanitizeSurrogates } from "
|
|
23
|
-
import { abortableSleep } from "@oh-my-pi/pi-utils";
|
|
20
|
+
} from "../types";
|
|
21
|
+
import { AssistantMessageEventStream } from "../utils/event-stream";
|
|
22
|
+
import { sanitizeSurrogates } from "../utils/sanitize-unicode";
|
|
24
23
|
import {
|
|
25
24
|
convertMessages,
|
|
26
25
|
convertTools,
|
|
@@ -660,8 +659,7 @@ export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli"> = (
|
|
|
660
659
|
|
|
661
660
|
const providedId = part.functionCall.id;
|
|
662
661
|
const needsNewId =
|
|
663
|
-
!providedId ||
|
|
664
|
-
output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
|
662
|
+
!providedId || output.content.some(b => b.type === "toolCall" && b.id === providedId);
|
|
665
663
|
const toolCallId = needsNewId
|
|
666
664
|
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
|
667
665
|
: providedId;
|
|
@@ -695,7 +693,7 @@ export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli"> = (
|
|
|
695
693
|
|
|
696
694
|
if (candidate?.finishReason) {
|
|
697
695
|
output.stopReason = mapStopReasonString(candidate.finishReason);
|
|
698
|
-
if (output.content.some(
|
|
696
|
+
if (output.content.some(b => b.type === "toolCall")) {
|
|
699
697
|
output.stopReason = "toolUse";
|
|
700
698
|
}
|
|
701
699
|
}
|
|
@@ -840,7 +838,7 @@ function deriveSessionId(context: Context): string | undefined {
|
|
|
840
838
|
} else if (Array.isArray(message.content)) {
|
|
841
839
|
text = message.content
|
|
842
840
|
.filter((item): item is TextContent => item.type === "text")
|
|
843
|
-
.map(
|
|
841
|
+
.map(item => item.text)
|
|
844
842
|
.join("\n");
|
|
845
843
|
}
|
|
846
844
|
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
|
|
3
3
|
*/
|
|
4
|
-
|
|
5
4
|
import { type Content, FinishReason, FunctionCallingConfigMode, type Part, type Schema } from "@google/genai";
|
|
6
|
-
import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from "
|
|
7
|
-
import { sanitizeSurrogates } from "
|
|
5
|
+
import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from "../types";
|
|
6
|
+
import { sanitizeSurrogates } from "../utils/sanitize-unicode";
|
|
8
7
|
import { transformMessages } from "./transform-messages";
|
|
9
8
|
|
|
10
9
|
type GoogleApiType = "google-generative-ai" | "google-gemini-cli" | "google-vertex";
|
|
@@ -86,7 +85,7 @@ export function convertMessages<T extends GoogleApiType>(model: Model<T>, contex
|
|
|
86
85
|
parts: [{ text: sanitizeSurrogates(msg.content) }],
|
|
87
86
|
});
|
|
88
87
|
} else {
|
|
89
|
-
const parts: Part[] = msg.content.map(
|
|
88
|
+
const parts: Part[] = msg.content.map(item => {
|
|
90
89
|
if (item.type === "text") {
|
|
91
90
|
return { text: sanitizeSurrogates(item.text) };
|
|
92
91
|
} else {
|
|
@@ -99,8 +98,8 @@ export function convertMessages<T extends GoogleApiType>(model: Model<T>, contex
|
|
|
99
98
|
}
|
|
100
99
|
});
|
|
101
100
|
// Filter out images if model doesn't support them, and empty text blocks
|
|
102
|
-
let filteredParts = !model.input.includes("image") ? parts.filter(
|
|
103
|
-
filteredParts = filteredParts.filter(
|
|
101
|
+
let filteredParts = !model.input.includes("image") ? parts.filter(p => p.text !== undefined) : parts;
|
|
102
|
+
filteredParts = filteredParts.filter(p => {
|
|
104
103
|
if (p.text !== undefined) {
|
|
105
104
|
return p.text.trim().length > 0;
|
|
106
105
|
}
|
|
@@ -180,7 +179,7 @@ export function convertMessages<T extends GoogleApiType>(model: Model<T>, contex
|
|
|
180
179
|
} else if (msg.role === "toolResult") {
|
|
181
180
|
// Extract text and image content
|
|
182
181
|
const textContent = msg.content.filter((c): c is TextContent => c.type === "text");
|
|
183
|
-
const textResult = textContent.map(
|
|
182
|
+
const textResult = textContent.map(c => c.text).join("\n");
|
|
184
183
|
const imageContent = model.input.includes("image")
|
|
185
184
|
? msg.content.filter((c): c is ImageContent => c.type === "image")
|
|
186
185
|
: [];
|
|
@@ -196,7 +195,7 @@ export function convertMessages<T extends GoogleApiType>(model: Model<T>, contex
|
|
|
196
195
|
// Use "output" key for success, "error" key for errors as per SDK documentation
|
|
197
196
|
const responseValue = hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "";
|
|
198
197
|
|
|
199
|
-
const imageParts: Part[] = imageContent.map(
|
|
198
|
+
const imageParts: Part[] = imageContent.map(imageBlock => ({
|
|
200
199
|
inlineData: {
|
|
201
200
|
mimeType: imageBlock.mimeType,
|
|
202
201
|
data: imageBlock.data,
|
|
@@ -221,7 +220,7 @@ export function convertMessages<T extends GoogleApiType>(model: Model<T>, contex
|
|
|
221
220
|
// Cloud Code Assist API requires all function responses to be in a single user turn.
|
|
222
221
|
// Check if the last content is already a user turn with function responses and merge.
|
|
223
222
|
const lastContent = contents[contents.length - 1];
|
|
224
|
-
if (lastContent?.role === "user" && lastContent.parts?.some(
|
|
223
|
+
if (lastContent?.role === "user" && lastContent.parts?.some(p => p.functionResponse)) {
|
|
225
224
|
lastContent.parts.push(functionResponsePart);
|
|
226
225
|
} else {
|
|
227
226
|
contents.push({
|
|
@@ -270,7 +269,7 @@ const UNSUPPORTED_SCHEMA_FIELDS = new Set([
|
|
|
270
269
|
|
|
271
270
|
function sanitizeSchemaImpl(value: unknown, isInsideProperties: boolean): unknown {
|
|
272
271
|
if (Array.isArray(value)) {
|
|
273
|
-
return value.map(
|
|
272
|
+
return value.map(entry => sanitizeSchemaImpl(entry, isInsideProperties));
|
|
274
273
|
}
|
|
275
274
|
|
|
276
275
|
if (!value || typeof value !== "object") {
|
|
@@ -286,11 +285,11 @@ function sanitizeSchemaImpl(value: unknown, isInsideProperties: boolean): unknow
|
|
|
286
285
|
const variants = obj[combiner] as Record<string, unknown>[];
|
|
287
286
|
|
|
288
287
|
// Check if ALL variants have a const field
|
|
289
|
-
const allHaveConst = variants.every(
|
|
288
|
+
const allHaveConst = variants.every(v => v && typeof v === "object" && "const" in v);
|
|
290
289
|
|
|
291
290
|
if (allHaveConst && variants.length > 0) {
|
|
292
291
|
// Extract all const values into enum
|
|
293
|
-
result.enum = variants.map(
|
|
292
|
+
result.enum = variants.map(v => v.const);
|
|
294
293
|
|
|
295
294
|
// Inherit type from first variant if present
|
|
296
295
|
const firstType = variants[0]?.type;
|
|
@@ -327,7 +326,7 @@ function sanitizeSchemaImpl(value: unknown, isInsideProperties: boolean): unknow
|
|
|
327
326
|
if (constValue !== undefined) {
|
|
328
327
|
// Convert const to enum, merging with existing enum if present
|
|
329
328
|
const existingEnum = Array.isArray(result.enum) ? result.enum : [];
|
|
330
|
-
if (!existingEnum.some(
|
|
329
|
+
if (!existingEnum.some(item => Object.is(item, constValue))) {
|
|
331
330
|
existingEnum.push(constValue);
|
|
332
331
|
}
|
|
333
332
|
result.enum = existingEnum;
|
|
@@ -5,7 +5,7 @@ import {
|
|
|
5
5
|
type ThinkingConfig,
|
|
6
6
|
ThinkingLevel,
|
|
7
7
|
} from "@google/genai";
|
|
8
|
-
import { calculateCost } from "
|
|
8
|
+
import { calculateCost } from "../models";
|
|
9
9
|
import type {
|
|
10
10
|
Api,
|
|
11
11
|
AssistantMessage,
|
|
@@ -16,10 +16,10 @@ import type {
|
|
|
16
16
|
TextContent,
|
|
17
17
|
ThinkingContent,
|
|
18
18
|
ToolCall,
|
|
19
|
-
} from "
|
|
20
|
-
import { AssistantMessageEventStream } from "
|
|
21
|
-
import { formatErrorMessageWithRetryAfter } from "
|
|
22
|
-
import { sanitizeSurrogates } from "
|
|
19
|
+
} from "../types";
|
|
20
|
+
import { AssistantMessageEventStream } from "../utils/event-stream";
|
|
21
|
+
import { formatErrorMessageWithRetryAfter } from "../utils/retry-after";
|
|
22
|
+
import { sanitizeSurrogates } from "../utils/sanitize-unicode";
|
|
23
23
|
import type { GoogleThinkingLevel } from "./google-gemini-cli";
|
|
24
24
|
import {
|
|
25
25
|
convertMessages,
|
|
@@ -183,7 +183,7 @@ export const streamGoogleVertex: StreamFunction<"google-vertex"> = (
|
|
|
183
183
|
|
|
184
184
|
const providedId = part.functionCall.id;
|
|
185
185
|
const needsNewId =
|
|
186
|
-
!providedId || output.content.some(
|
|
186
|
+
!providedId || output.content.some(b => b.type === "toolCall" && b.id === providedId);
|
|
187
187
|
const toolCallId = needsNewId
|
|
188
188
|
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
|
189
189
|
: providedId;
|
|
@@ -211,7 +211,7 @@ export const streamGoogleVertex: StreamFunction<"google-vertex"> = (
|
|
|
211
211
|
|
|
212
212
|
if (candidate?.finishReason) {
|
|
213
213
|
output.stopReason = mapStopReason(candidate.finishReason);
|
|
214
|
-
if (output.content.some(
|
|
214
|
+
if (output.content.some(b => b.type === "toolCall")) {
|
|
215
215
|
output.stopReason = "toolUse";
|
|
216
216
|
}
|
|
217
217
|
}
|
package/src/providers/google.ts
CHANGED
|
@@ -4,8 +4,8 @@ import {
|
|
|
4
4
|
GoogleGenAI,
|
|
5
5
|
type ThinkingConfig,
|
|
6
6
|
} from "@google/genai";
|
|
7
|
-
import { calculateCost } from "
|
|
8
|
-
import { getEnvApiKey } from "
|
|
7
|
+
import { calculateCost } from "../models";
|
|
8
|
+
import { getEnvApiKey } from "../stream";
|
|
9
9
|
import type {
|
|
10
10
|
Api,
|
|
11
11
|
AssistantMessage,
|
|
@@ -16,10 +16,10 @@ import type {
|
|
|
16
16
|
TextContent,
|
|
17
17
|
ThinkingContent,
|
|
18
18
|
ToolCall,
|
|
19
|
-
} from "
|
|
20
|
-
import { AssistantMessageEventStream } from "
|
|
21
|
-
import { formatErrorMessageWithRetryAfter } from "
|
|
22
|
-
import { sanitizeSurrogates } from "
|
|
19
|
+
} from "../types";
|
|
20
|
+
import { AssistantMessageEventStream } from "../utils/event-stream";
|
|
21
|
+
import { formatErrorMessageWithRetryAfter } from "../utils/retry-after";
|
|
22
|
+
import { sanitizeSurrogates } from "../utils/sanitize-unicode";
|
|
23
23
|
import type { GoogleThinkingLevel } from "./google-gemini-cli";
|
|
24
24
|
import {
|
|
25
25
|
convertMessages,
|
|
@@ -170,7 +170,7 @@ export const streamGoogle: StreamFunction<"google-generative-ai"> = (
|
|
|
170
170
|
// Generate unique ID if not provided or if it's a duplicate
|
|
171
171
|
const providedId = part.functionCall.id;
|
|
172
172
|
const needsNewId =
|
|
173
|
-
!providedId || output.content.some(
|
|
173
|
+
!providedId || output.content.some(b => b.type === "toolCall" && b.id === providedId);
|
|
174
174
|
const toolCallId = needsNewId
|
|
175
175
|
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
|
176
176
|
: providedId;
|
|
@@ -198,7 +198,7 @@ export const streamGoogle: StreamFunction<"google-generative-ai"> = (
|
|
|
198
198
|
|
|
199
199
|
if (candidate?.finishReason) {
|
|
200
200
|
output.stopReason = mapStopReason(candidate.finishReason);
|
|
201
|
-
if (output.content.some(
|
|
201
|
+
if (output.content.some(b => b.type === "toolCall")) {
|
|
202
202
|
output.stopReason = "toolUse";
|
|
203
203
|
}
|
|
204
204
|
}
|
|
@@ -73,8 +73,8 @@ function filterInput(input: InputItem[] | undefined): InputItem[] | undefined {
|
|
|
73
73
|
if (!Array.isArray(input)) return input;
|
|
74
74
|
|
|
75
75
|
return input
|
|
76
|
-
.filter(
|
|
77
|
-
.map(
|
|
76
|
+
.filter(item => item.type !== "item_reference")
|
|
77
|
+
.map(item => {
|
|
78
78
|
if (item.id != null) {
|
|
79
79
|
const { id: _id, ...rest } = item;
|
|
80
80
|
return rest as InputItem;
|
|
@@ -97,11 +97,11 @@ export async function transformRequestBody(
|
|
|
97
97
|
if (body.input) {
|
|
98
98
|
const functionCallIds = new Set(
|
|
99
99
|
body.input
|
|
100
|
-
.filter(
|
|
101
|
-
.map(
|
|
100
|
+
.filter(item => item.type === "function_call" && typeof item.call_id === "string")
|
|
101
|
+
.map(item => item.call_id as string),
|
|
102
102
|
);
|
|
103
103
|
|
|
104
|
-
body.input = body.input.map(
|
|
104
|
+
body.input = body.input.map(item => {
|
|
105
105
|
if (item.type === "function_call_output" && typeof item.call_id === "string") {
|
|
106
106
|
const callId = item.call_id as string;
|
|
107
107
|
if (!functionCallIds.has(callId)) {
|
|
@@ -131,7 +131,7 @@ export async function transformRequestBody(
|
|
|
131
131
|
|
|
132
132
|
if (prompt?.developerMessages && prompt.developerMessages.length > 0 && Array.isArray(body.input)) {
|
|
133
133
|
const developerMessages = prompt.developerMessages.map(
|
|
134
|
-
|
|
134
|
+
text =>
|
|
135
135
|
({
|
|
136
136
|
type: "message",
|
|
137
137
|
role: "developer",
|