@oh-my-pi/pi-ai 11.8.1 → 11.8.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/src/cli.ts +5 -5
- package/src/providers/cursor.ts +1 -1
- package/src/providers/google-gemini-cli.ts +3 -3
- package/src/storage.ts +23 -23
- package/src/utils/event-stream.ts +40 -39
- package/src/utils/oauth/anthropic.ts +8 -11
- package/src/utils/oauth/callback-server.ts +23 -28
- package/src/utils/oauth/google-antigravity.ts +8 -11
- package/src/utils/oauth/google-gemini-cli.ts +8 -11
- package/src/utils/oauth/openai-codex.ts +2 -5
- package/src/utils/validation.ts +7 -30
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@oh-my-pi/pi-ai",
|
|
3
|
-
"version": "11.8.
|
|
3
|
+
"version": "11.8.3",
|
|
4
4
|
"description": "Unified LLM API with automatic model discovery and provider configuration",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "./src/index.ts",
|
|
@@ -63,7 +63,7 @@
|
|
|
63
63
|
"@connectrpc/connect-node": "^2.1.1",
|
|
64
64
|
"@google/genai": "^1.39.0",
|
|
65
65
|
"@mistralai/mistralai": "^1.13.0",
|
|
66
|
-
"@oh-my-pi/pi-utils": "11.8.
|
|
66
|
+
"@oh-my-pi/pi-utils": "11.8.3",
|
|
67
67
|
"@sinclair/typebox": "^0.34.48",
|
|
68
68
|
"@smithy/node-http-handler": "^4.4.9",
|
|
69
69
|
"ajv": "^8.17.1",
|
package/src/cli.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
#!/usr/bin/env bun
|
|
2
|
-
import
|
|
2
|
+
import * as readline from "node:readline";
|
|
3
3
|
import { CliAuthStorage } from "./storage";
|
|
4
4
|
import { getOAuthProviders } from "./utils/oauth";
|
|
5
5
|
import { loginAnthropic } from "./utils/oauth/anthropic";
|
|
@@ -13,14 +13,14 @@ import type { OAuthCredentials, OAuthProvider } from "./utils/oauth/types";
|
|
|
13
13
|
|
|
14
14
|
const PROVIDERS = getOAuthProviders();
|
|
15
15
|
|
|
16
|
-
function prompt(rl:
|
|
16
|
+
function prompt(rl: readline.Interface, question: string): Promise<string> {
|
|
17
17
|
const { promise, resolve } = Promise.withResolvers<string>();
|
|
18
18
|
rl.question(question, resolve);
|
|
19
19
|
return promise;
|
|
20
20
|
}
|
|
21
21
|
|
|
22
22
|
async function login(provider: OAuthProvider): Promise<void> {
|
|
23
|
-
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
|
23
|
+
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
|
24
24
|
|
|
25
25
|
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
|
26
26
|
const storage = await CliAuthStorage.create();
|
|
@@ -201,7 +201,7 @@ Examples:
|
|
|
201
201
|
return;
|
|
202
202
|
}
|
|
203
203
|
|
|
204
|
-
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
|
204
|
+
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
|
205
205
|
console.log("Select a provider to logout:\n");
|
|
206
206
|
for (let i = 0; i < providers.length; i++) {
|
|
207
207
|
console.log(` ${i + 1}. ${providers[i]}`);
|
|
@@ -237,7 +237,7 @@ Examples:
|
|
|
237
237
|
let provider = args[1] as OAuthProvider | undefined;
|
|
238
238
|
|
|
239
239
|
if (!provider) {
|
|
240
|
-
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
|
240
|
+
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
|
241
241
|
console.log("Select a provider:\n");
|
|
242
242
|
for (let i = 0; i < PROVIDERS.length; i++) {
|
|
243
243
|
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
|
package/src/providers/cursor.ts
CHANGED
|
@@ -317,7 +317,7 @@ export const streamCursor: StreamFunction<"cursor-agent"> = (
|
|
|
317
317
|
|
|
318
318
|
let h2Client: http2.ClientHttp2Session | null = null;
|
|
319
319
|
let h2Request: http2.ClientHttp2Stream | null = null;
|
|
320
|
-
let heartbeatTimer:
|
|
320
|
+
let heartbeatTimer: NodeJS.Timeout | null = null;
|
|
321
321
|
|
|
322
322
|
try {
|
|
323
323
|
const apiKey = options?.apiKey;
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
|
|
5
5
|
*/
|
|
6
6
|
import { createHash } from "node:crypto";
|
|
7
|
-
import type { Content, ThinkingConfig } from "@google/genai";
|
|
7
|
+
import type { Content, FunctionCallingConfigMode, ThinkingConfig } from "@google/genai";
|
|
8
8
|
import { abortableSleep, readSseJson } from "@oh-my-pi/pi-utils";
|
|
9
9
|
import { calculateCost } from "../models";
|
|
10
10
|
import type {
|
|
@@ -244,10 +244,10 @@ interface CloudCodeAssistRequest {
|
|
|
244
244
|
temperature?: number;
|
|
245
245
|
thinkingConfig?: ThinkingConfig;
|
|
246
246
|
};
|
|
247
|
-
tools?:
|
|
247
|
+
tools?: { functionDeclarations: Record<string, unknown>[] }[] | undefined;
|
|
248
248
|
toolConfig?: {
|
|
249
249
|
functionCallingConfig: {
|
|
250
|
-
mode:
|
|
250
|
+
mode: FunctionCallingConfigMode;
|
|
251
251
|
};
|
|
252
252
|
};
|
|
253
253
|
};
|
package/src/storage.ts
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
* Compatible with coding-agent's agent.db format.
|
|
4
4
|
*/
|
|
5
5
|
|
|
6
|
-
import { Database } from "bun:sqlite";
|
|
6
|
+
import { Database, type Statement } from "bun:sqlite";
|
|
7
7
|
import * as fs from "node:fs/promises";
|
|
8
8
|
import * as os from "node:os";
|
|
9
9
|
import * as path from "node:path";
|
|
@@ -84,22 +84,22 @@ function deserializeCredential(row: AuthRow): AuthCredential | null {
|
|
|
84
84
|
* Use `CliAuthStorage.create()` to instantiate (async initialization).
|
|
85
85
|
*/
|
|
86
86
|
export class CliAuthStorage {
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
87
|
+
#db: Database;
|
|
88
|
+
#insertStmt: Statement;
|
|
89
|
+
#listByProviderStmt: Statement;
|
|
90
|
+
#listAllStmt: Statement;
|
|
91
|
+
#deleteByProviderStmt: Statement;
|
|
92
92
|
|
|
93
93
|
private constructor(db: Database) {
|
|
94
|
-
this
|
|
95
|
-
this
|
|
94
|
+
this.#db = db;
|
|
95
|
+
this.#initializeSchema();
|
|
96
96
|
|
|
97
|
-
this
|
|
97
|
+
this.#insertStmt = this.#db.prepare(
|
|
98
98
|
"INSERT INTO auth_credentials (provider, credential_type, data) VALUES (?, ?, ?) RETURNING id",
|
|
99
99
|
);
|
|
100
|
-
this
|
|
101
|
-
this
|
|
102
|
-
this
|
|
100
|
+
this.#listByProviderStmt = this.#db.prepare("SELECT * FROM auth_credentials WHERE provider = ?");
|
|
101
|
+
this.#listAllStmt = this.#db.prepare("SELECT * FROM auth_credentials");
|
|
102
|
+
this.#deleteByProviderStmt = this.#db.prepare("DELETE FROM auth_credentials WHERE provider = ?");
|
|
103
103
|
}
|
|
104
104
|
|
|
105
105
|
static async create(dbPath: string = getAgentDbPath()): Promise<CliAuthStorage> {
|
|
@@ -122,8 +122,8 @@ export class CliAuthStorage {
|
|
|
122
122
|
return new CliAuthStorage(db);
|
|
123
123
|
}
|
|
124
124
|
|
|
125
|
-
|
|
126
|
-
this
|
|
125
|
+
#initializeSchema(): void {
|
|
126
|
+
this.#db.exec(`
|
|
127
127
|
PRAGMA journal_mode=WAL;
|
|
128
128
|
PRAGMA synchronous=NORMAL;
|
|
129
129
|
PRAGMA busy_timeout=5000;
|
|
@@ -145,14 +145,14 @@ CREATE INDEX IF NOT EXISTS idx_auth_provider ON auth_credentials(provider);
|
|
|
145
145
|
*/
|
|
146
146
|
saveOAuth(provider: string, credentials: OAuthCredentials): void {
|
|
147
147
|
const credential: AuthCredential = { type: "oauth", ...credentials };
|
|
148
|
-
this
|
|
148
|
+
this.#replaceForProvider(provider, credential);
|
|
149
149
|
}
|
|
150
150
|
|
|
151
151
|
/**
|
|
152
152
|
* Get OAuth credentials for a provider.
|
|
153
153
|
*/
|
|
154
154
|
getOAuth(provider: string): OAuthCredentials | null {
|
|
155
|
-
const rows = this
|
|
155
|
+
const rows = this.#listByProviderStmt.all(provider) as AuthRow[];
|
|
156
156
|
for (const row of rows) {
|
|
157
157
|
const credential = deserializeCredential(row);
|
|
158
158
|
if (credential && credential.type === "oauth") {
|
|
@@ -167,7 +167,7 @@ CREATE INDEX IF NOT EXISTS idx_auth_provider ON auth_credentials(provider);
|
|
|
167
167
|
* List all providers with credentials.
|
|
168
168
|
*/
|
|
169
169
|
listProviders(): string[] {
|
|
170
|
-
const rows = this
|
|
170
|
+
const rows = this.#listAllStmt.all() as AuthRow[];
|
|
171
171
|
const providers = new Set<string>();
|
|
172
172
|
for (const row of rows) {
|
|
173
173
|
providers.add(row.provider);
|
|
@@ -179,24 +179,24 @@ CREATE INDEX IF NOT EXISTS idx_auth_provider ON auth_credentials(provider);
|
|
|
179
179
|
* Delete all credentials for a provider.
|
|
180
180
|
*/
|
|
181
181
|
deleteProvider(provider: string): void {
|
|
182
|
-
this
|
|
182
|
+
this.#deleteByProviderStmt.run(provider);
|
|
183
183
|
}
|
|
184
184
|
|
|
185
185
|
/**
|
|
186
186
|
* Replace all credentials for a provider with a single credential.
|
|
187
187
|
*/
|
|
188
|
-
|
|
188
|
+
#replaceForProvider(provider: string, credential: AuthCredential): void {
|
|
189
189
|
const serialized = serializeCredential(credential);
|
|
190
190
|
if (!serialized) return;
|
|
191
191
|
|
|
192
|
-
const replace = this
|
|
193
|
-
this
|
|
194
|
-
this
|
|
192
|
+
const replace = this.#db.transaction(() => {
|
|
193
|
+
this.#deleteByProviderStmt.run(provider);
|
|
194
|
+
this.#insertStmt.run(provider, serialized.credentialType, serialized.data);
|
|
195
195
|
});
|
|
196
196
|
replace();
|
|
197
197
|
}
|
|
198
198
|
|
|
199
199
|
close(): void {
|
|
200
|
-
this
|
|
200
|
+
this.#db.close();
|
|
201
201
|
}
|
|
202
202
|
}
|
|
@@ -2,19 +2,20 @@ import type { AssistantMessage, AssistantMessageEvent } from "../types";
|
|
|
2
2
|
|
|
3
3
|
// Generic event stream class for async iteration
|
|
4
4
|
export class EventStream<T, R = T> implements AsyncIterable<T> {
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
) {
|
|
5
|
+
queue: T[] = [];
|
|
6
|
+
waiting: ((value: IteratorResult<T>) => void)[] = [];
|
|
7
|
+
done = false;
|
|
8
|
+
finalResultPromise: Promise<R>;
|
|
9
|
+
resolveFinalResult!: (result: R) => void;
|
|
10
|
+
isComplete: (event: T) => boolean;
|
|
11
|
+
extractResult: (event: T) => R;
|
|
12
|
+
|
|
13
|
+
constructor(isComplete: (event: T) => boolean, extractResult: (event: T) => R) {
|
|
15
14
|
const { promise, resolve } = Promise.withResolvers<R>();
|
|
16
15
|
this.finalResultPromise = promise;
|
|
17
16
|
this.resolveFinalResult = resolve;
|
|
17
|
+
this.isComplete = isComplete;
|
|
18
|
+
this.extractResult = extractResult;
|
|
18
19
|
}
|
|
19
20
|
|
|
20
21
|
push(event: T): void {
|
|
@@ -34,7 +35,7 @@ export class EventStream<T, R = T> implements AsyncIterable<T> {
|
|
|
34
35
|
}
|
|
35
36
|
}
|
|
36
37
|
|
|
37
|
-
|
|
38
|
+
deliver(event: T): void {
|
|
38
39
|
const waiter = this.waiting.shift();
|
|
39
40
|
if (waiter) {
|
|
40
41
|
waiter({ value: event, done: false });
|
|
@@ -55,7 +56,7 @@ export class EventStream<T, R = T> implements AsyncIterable<T> {
|
|
|
55
56
|
}
|
|
56
57
|
}
|
|
57
58
|
|
|
58
|
-
|
|
59
|
+
endWaiting(): void {
|
|
59
60
|
while (this.waiting.length > 0) {
|
|
60
61
|
const waiter = this.waiting.shift()!;
|
|
61
62
|
waiter({ value: undefined as any, done: true });
|
|
@@ -93,10 +94,10 @@ function isDeltaEvent(event: AssistantMessageEvent): event is DeltaEvent {
|
|
|
93
94
|
|
|
94
95
|
export class AssistantMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
|
95
96
|
// Throttling state
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
97
|
+
#deltaBuffer: DeltaEvent[] = [];
|
|
98
|
+
#flushTimer?: NodeJS.Timeout;
|
|
99
|
+
#lastFlushTime = 0;
|
|
100
|
+
readonly #throttleMs = 50; // 20 updates/sec
|
|
100
101
|
|
|
101
102
|
constructor() {
|
|
102
103
|
super(
|
|
@@ -117,25 +118,25 @@ export class AssistantMessageEventStream extends EventStream<AssistantMessageEve
|
|
|
117
118
|
|
|
118
119
|
// Check for completion first
|
|
119
120
|
if (this.isComplete(event)) {
|
|
120
|
-
this
|
|
121
|
+
this.#flushDeltas(); // Flush any pending deltas before completing
|
|
121
122
|
this.done = true;
|
|
122
123
|
this.resolveFinalResult(this.extractResult(event));
|
|
123
124
|
}
|
|
124
125
|
|
|
125
126
|
// Delta events get batched and throttled
|
|
126
127
|
if (isDeltaEvent(event)) {
|
|
127
|
-
this
|
|
128
|
-
this
|
|
128
|
+
this.#deltaBuffer.push(event);
|
|
129
|
+
this.#scheduleFlush();
|
|
129
130
|
return;
|
|
130
131
|
}
|
|
131
132
|
|
|
132
133
|
// Non-delta events flush pending deltas immediately, then emit
|
|
133
|
-
this
|
|
134
|
+
this.#flushDeltas();
|
|
134
135
|
this.deliver(event);
|
|
135
136
|
}
|
|
136
137
|
|
|
137
138
|
override end(result?: AssistantMessage): void {
|
|
138
|
-
this
|
|
139
|
+
this.#flushDeltas();
|
|
139
140
|
this.done = true;
|
|
140
141
|
if (result !== undefined) {
|
|
141
142
|
this.resolveFinalResult(result);
|
|
@@ -143,44 +144,44 @@ export class AssistantMessageEventStream extends EventStream<AssistantMessageEve
|
|
|
143
144
|
this.endWaiting();
|
|
144
145
|
}
|
|
145
146
|
|
|
146
|
-
|
|
147
|
-
if (this
|
|
147
|
+
#scheduleFlush(): void {
|
|
148
|
+
if (this.#flushTimer) return; // Already scheduled
|
|
148
149
|
|
|
149
150
|
const now = Bun.nanoseconds();
|
|
150
|
-
const timeSinceLastFlush = (now - this
|
|
151
|
+
const timeSinceLastFlush = (now - this.#lastFlushTime) / 1e6;
|
|
151
152
|
|
|
152
|
-
if (timeSinceLastFlush >= this
|
|
153
|
+
if (timeSinceLastFlush >= this.#throttleMs) {
|
|
153
154
|
// Flush immediately if throttle window has passed
|
|
154
|
-
this
|
|
155
|
+
this.#flushDeltas();
|
|
155
156
|
} else {
|
|
156
157
|
// Schedule flush for when throttle window expires
|
|
157
|
-
const delay = this
|
|
158
|
-
this
|
|
159
|
-
this
|
|
160
|
-
this
|
|
158
|
+
const delay = this.#throttleMs - timeSinceLastFlush;
|
|
159
|
+
this.#flushTimer = setTimeout(() => {
|
|
160
|
+
this.#flushTimer = undefined;
|
|
161
|
+
this.#flushDeltas();
|
|
161
162
|
}, delay);
|
|
162
163
|
}
|
|
163
164
|
}
|
|
164
165
|
|
|
165
|
-
|
|
166
|
-
if (this
|
|
167
|
-
clearTimeout(this
|
|
168
|
-
this
|
|
166
|
+
#flushDeltas(): void {
|
|
167
|
+
if (this.#flushTimer) {
|
|
168
|
+
clearTimeout(this.#flushTimer);
|
|
169
|
+
this.#flushTimer = undefined;
|
|
169
170
|
}
|
|
170
171
|
|
|
171
|
-
if (this
|
|
172
|
+
if (this.#deltaBuffer.length === 0) return;
|
|
172
173
|
|
|
173
174
|
// Merge consecutive deltas for the same content block and type
|
|
174
|
-
const merged = this
|
|
175
|
-
this
|
|
176
|
-
this
|
|
175
|
+
const merged = this.#mergeDeltas(this.#deltaBuffer);
|
|
176
|
+
this.#deltaBuffer = [];
|
|
177
|
+
this.#lastFlushTime = Bun.nanoseconds();
|
|
177
178
|
|
|
178
179
|
for (const event of merged) {
|
|
179
180
|
this.deliver(event);
|
|
180
181
|
}
|
|
181
182
|
}
|
|
182
183
|
|
|
183
|
-
|
|
184
|
+
#mergeDeltas(deltas: DeltaEvent[]): AssistantMessageEvent[] {
|
|
184
185
|
if (deltas.length === 0) return [];
|
|
185
186
|
if (deltas.length === 1) return [deltas[0]];
|
|
186
187
|
|
|
@@ -14,20 +14,17 @@ const CALLBACK_PATH = "/callback";
|
|
|
14
14
|
const SCOPES = "org:create_api_key user:profile user:inference";
|
|
15
15
|
|
|
16
16
|
class AnthropicOAuthFlow extends OAuthCallbackFlow {
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
#verifier: string = "";
|
|
18
|
+
#challenge: string = "";
|
|
19
19
|
|
|
20
20
|
constructor(ctrl: OAuthController) {
|
|
21
21
|
super(ctrl, CALLBACK_PORT, CALLBACK_PATH);
|
|
22
22
|
}
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
state: string,
|
|
26
|
-
redirectUri: string,
|
|
27
|
-
): Promise<{ url: string; instructions?: string }> {
|
|
24
|
+
async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
|
|
28
25
|
const pkce = await generatePKCE();
|
|
29
|
-
this
|
|
30
|
-
this
|
|
26
|
+
this.#verifier = pkce.verifier;
|
|
27
|
+
this.#challenge = pkce.challenge;
|
|
31
28
|
|
|
32
29
|
const authParams = new URLSearchParams({
|
|
33
30
|
code: "true",
|
|
@@ -35,7 +32,7 @@ class AnthropicOAuthFlow extends OAuthCallbackFlow {
|
|
|
35
32
|
response_type: "code",
|
|
36
33
|
redirect_uri: redirectUri,
|
|
37
34
|
scope: SCOPES,
|
|
38
|
-
code_challenge: this
|
|
35
|
+
code_challenge: this.#challenge,
|
|
39
36
|
code_challenge_method: "S256",
|
|
40
37
|
state,
|
|
41
38
|
});
|
|
@@ -44,7 +41,7 @@ class AnthropicOAuthFlow extends OAuthCallbackFlow {
|
|
|
44
41
|
return { url };
|
|
45
42
|
}
|
|
46
43
|
|
|
47
|
-
|
|
44
|
+
async exchangeToken(code: string, state: string, redirectUri: string): Promise<OAuthCredentials> {
|
|
48
45
|
const tokenResponse = await fetch(TOKEN_URL, {
|
|
49
46
|
method: "POST",
|
|
50
47
|
headers: {
|
|
@@ -57,7 +54,7 @@ class AnthropicOAuthFlow extends OAuthCallbackFlow {
|
|
|
57
54
|
code,
|
|
58
55
|
state,
|
|
59
56
|
redirect_uri: redirectUri,
|
|
60
|
-
code_verifier: this
|
|
57
|
+
code_verifier: this.#verifier,
|
|
61
58
|
}),
|
|
62
59
|
});
|
|
63
60
|
|
|
@@ -23,11 +23,11 @@ export type CallbackResult = { code: string; state: string };
|
|
|
23
23
|
* Abstract base class for OAuth flows with local callback servers.
|
|
24
24
|
*/
|
|
25
25
|
export abstract class OAuthCallbackFlow {
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
26
|
+
ctrl: OAuthController;
|
|
27
|
+
preferredPort: number;
|
|
28
|
+
callbackPath: string;
|
|
29
|
+
#callbackResolve?: (result: CallbackResult) => void;
|
|
30
|
+
#callbackReject?: (error: string) => void;
|
|
31
31
|
|
|
32
32
|
constructor(ctrl: OAuthController, preferredPort: number, callbackPath: string = CALLBACK_PATH) {
|
|
33
33
|
this.ctrl = ctrl;
|
|
@@ -41,10 +41,7 @@ export abstract class OAuthCallbackFlow {
|
|
|
41
41
|
* @param redirectUri - The actual redirect URI to use (may differ from expected if port fallback occurred)
|
|
42
42
|
* @returns Authorization URL and optional instructions
|
|
43
43
|
*/
|
|
44
|
-
|
|
45
|
-
state: string,
|
|
46
|
-
redirectUri: string,
|
|
47
|
-
): Promise<{ url: string; instructions?: string }>;
|
|
44
|
+
abstract generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }>;
|
|
48
45
|
|
|
49
46
|
/**
|
|
50
47
|
* Exchange authorization code for OAuth tokens.
|
|
@@ -53,12 +50,12 @@ export abstract class OAuthCallbackFlow {
|
|
|
53
50
|
* @param redirectUri - The actual redirect URI used (must match authorization request)
|
|
54
51
|
* @returns OAuth credentials
|
|
55
52
|
*/
|
|
56
|
-
|
|
53
|
+
abstract exchangeToken(code: string, state: string, redirectUri: string): Promise<OAuthCredentials>;
|
|
57
54
|
|
|
58
55
|
/**
|
|
59
56
|
* Generate CSRF state token. Override if provider needs custom state generation.
|
|
60
57
|
*/
|
|
61
|
-
|
|
58
|
+
generateState(): string {
|
|
62
59
|
const bytes = new Uint8Array(16);
|
|
63
60
|
crypto.getRandomValues(bytes);
|
|
64
61
|
return Array.from(bytes)
|
|
@@ -73,7 +70,7 @@ export abstract class OAuthCallbackFlow {
|
|
|
73
70
|
const state = this.generateState();
|
|
74
71
|
|
|
75
72
|
// Start callback server first to get actual redirect URI
|
|
76
|
-
const { server, redirectUri } = await this
|
|
73
|
+
const { server, redirectUri } = await this.#startCallbackServer(state);
|
|
77
74
|
|
|
78
75
|
try {
|
|
79
76
|
// Generate auth URL with the ACTUAL redirect URI (may differ from expected if port was busy)
|
|
@@ -84,7 +81,7 @@ export abstract class OAuthCallbackFlow {
|
|
|
84
81
|
this.ctrl.onProgress?.("Waiting for browser authentication...");
|
|
85
82
|
|
|
86
83
|
// Wait for callback or manual input
|
|
87
|
-
const { code } = await this
|
|
84
|
+
const { code } = await this.#waitForCallback(state);
|
|
88
85
|
|
|
89
86
|
this.ctrl.onProgress?.("Exchanging authorization code for tokens...");
|
|
90
87
|
|
|
@@ -97,18 +94,16 @@ export abstract class OAuthCallbackFlow {
|
|
|
97
94
|
/**
|
|
98
95
|
* Start callback server, trying preferred port first, falling back to random.
|
|
99
96
|
*/
|
|
100
|
-
|
|
101
|
-
expectedState: string,
|
|
102
|
-
): Promise<{ server: Bun.Server<unknown>; redirectUri: string }> {
|
|
97
|
+
async #startCallbackServer(expectedState: string): Promise<{ server: Bun.Server<unknown>; redirectUri: string }> {
|
|
103
98
|
// Try preferred port first
|
|
104
99
|
try {
|
|
105
100
|
const redirectUri = `http://${DEFAULT_HOSTNAME}:${this.preferredPort}${this.callbackPath}`;
|
|
106
|
-
const server = this
|
|
101
|
+
const server = this.#createServer(this.preferredPort, expectedState);
|
|
107
102
|
return { server, redirectUri };
|
|
108
103
|
} catch {
|
|
109
104
|
// Port busy or unavailable, try random port
|
|
110
105
|
const randomPort = 0; // Let OS assign
|
|
111
|
-
const server = this
|
|
106
|
+
const server = this.#createServer(randomPort, expectedState);
|
|
112
107
|
const actualPort = server.port;
|
|
113
108
|
const redirectUri = `http://${DEFAULT_HOSTNAME}:${actualPort}${this.callbackPath}`;
|
|
114
109
|
this.ctrl.onProgress?.(`Preferred port ${this.preferredPort} unavailable, using port ${actualPort}`);
|
|
@@ -119,19 +114,19 @@ export abstract class OAuthCallbackFlow {
|
|
|
119
114
|
/**
|
|
120
115
|
* Create HTTP server for OAuth callback.
|
|
121
116
|
*/
|
|
122
|
-
|
|
117
|
+
#createServer(port: number, expectedState: string): Bun.Server<unknown> {
|
|
123
118
|
return Bun.serve({
|
|
124
119
|
hostname: DEFAULT_HOSTNAME,
|
|
125
120
|
port,
|
|
126
121
|
reusePort: false,
|
|
127
|
-
fetch: req => this
|
|
122
|
+
fetch: req => this.#handleCallback(req, expectedState),
|
|
128
123
|
});
|
|
129
124
|
}
|
|
130
125
|
|
|
131
126
|
/**
|
|
132
127
|
* Handle OAuth callback HTTP request.
|
|
133
128
|
*/
|
|
134
|
-
|
|
129
|
+
#handleCallback(req: Request, expectedState: string): Response {
|
|
135
130
|
const url = new URL(req.url);
|
|
136
131
|
|
|
137
132
|
if (url.pathname !== this.callbackPath) {
|
|
@@ -158,8 +153,8 @@ export abstract class OAuthCallbackFlow {
|
|
|
158
153
|
}
|
|
159
154
|
|
|
160
155
|
// Signal to waitForCallback - capture refs before they could be cleared
|
|
161
|
-
const resolve = this
|
|
162
|
-
const reject = this
|
|
156
|
+
const resolve = this.#callbackResolve;
|
|
157
|
+
const reject = this.#callbackReject;
|
|
163
158
|
queueMicrotask(() => {
|
|
164
159
|
if (resultState.ok) {
|
|
165
160
|
resolve?.({ code: resultState.code, state: resultState.state });
|
|
@@ -180,17 +175,17 @@ export abstract class OAuthCallbackFlow {
|
|
|
180
175
|
/**
|
|
181
176
|
* Wait for OAuth callback or manual input (whichever comes first).
|
|
182
177
|
*/
|
|
183
|
-
|
|
178
|
+
#waitForCallback(expectedState: string): Promise<CallbackResult> {
|
|
184
179
|
const timeoutSignal = AbortSignal.timeout(DEFAULT_TIMEOUT);
|
|
185
180
|
const signal = this.ctrl.signal ? AbortSignal.any([this.ctrl.signal, timeoutSignal]) : timeoutSignal;
|
|
186
181
|
|
|
187
182
|
const callbackPromise = new Promise<CallbackResult>((resolve, reject) => {
|
|
188
|
-
this
|
|
189
|
-
this
|
|
183
|
+
this.#callbackResolve = resolve;
|
|
184
|
+
this.#callbackReject = reject;
|
|
190
185
|
|
|
191
186
|
signal.addEventListener("abort", () => {
|
|
192
|
-
this
|
|
193
|
-
this
|
|
187
|
+
this.#callbackResolve = undefined;
|
|
188
|
+
this.#callbackReject = undefined;
|
|
194
189
|
reject(new Error(`OAuth callback cancelled: ${signal.reason}`));
|
|
195
190
|
});
|
|
196
191
|
});
|
|
@@ -103,27 +103,24 @@ async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
|
|
103
103
|
}
|
|
104
104
|
|
|
105
105
|
class AntigravityOAuthFlow extends OAuthCallbackFlow {
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
#verifier: string = "";
|
|
107
|
+
#challenge: string = "";
|
|
108
108
|
|
|
109
109
|
constructor(ctrl: OAuthController) {
|
|
110
110
|
super(ctrl, CALLBACK_PORT, CALLBACK_PATH);
|
|
111
111
|
}
|
|
112
112
|
|
|
113
|
-
|
|
114
|
-
state: string,
|
|
115
|
-
redirectUri: string,
|
|
116
|
-
): Promise<{ url: string; instructions?: string }> {
|
|
113
|
+
async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
|
|
117
114
|
const pkce = await generatePKCE();
|
|
118
|
-
this
|
|
119
|
-
this
|
|
115
|
+
this.#verifier = pkce.verifier;
|
|
116
|
+
this.#challenge = pkce.challenge;
|
|
120
117
|
|
|
121
118
|
const authParams = new URLSearchParams({
|
|
122
119
|
client_id: CLIENT_ID,
|
|
123
120
|
response_type: "code",
|
|
124
121
|
redirect_uri: redirectUri,
|
|
125
122
|
scope: SCOPES.join(" "),
|
|
126
|
-
code_challenge: this
|
|
123
|
+
code_challenge: this.#challenge,
|
|
127
124
|
code_challenge_method: "S256",
|
|
128
125
|
state,
|
|
129
126
|
access_type: "offline",
|
|
@@ -134,7 +131,7 @@ class AntigravityOAuthFlow extends OAuthCallbackFlow {
|
|
|
134
131
|
return { url, instructions: "Complete the sign-in in your browser." };
|
|
135
132
|
}
|
|
136
133
|
|
|
137
|
-
|
|
134
|
+
async exchangeToken(code: string, _state: string, redirectUri: string): Promise<OAuthCredentials> {
|
|
138
135
|
this.ctrl.onProgress?.("Exchanging authorization code for tokens...");
|
|
139
136
|
|
|
140
137
|
const tokenResponse = await fetch(TOKEN_URL, {
|
|
@@ -146,7 +143,7 @@ class AntigravityOAuthFlow extends OAuthCallbackFlow {
|
|
|
146
143
|
code,
|
|
147
144
|
grant_type: "authorization_code",
|
|
148
145
|
redirect_uri: redirectUri,
|
|
149
|
-
code_verifier: this
|
|
146
|
+
code_verifier: this.#verifier,
|
|
150
147
|
}),
|
|
151
148
|
});
|
|
152
149
|
|
|
@@ -228,27 +228,24 @@ async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
|
|
228
228
|
}
|
|
229
229
|
|
|
230
230
|
class GeminiCliOAuthFlow extends OAuthCallbackFlow {
|
|
231
|
-
|
|
232
|
-
|
|
231
|
+
#verifier: string = "";
|
|
232
|
+
#challenge: string = "";
|
|
233
233
|
|
|
234
234
|
constructor(ctrl: OAuthController) {
|
|
235
235
|
super(ctrl, CALLBACK_PORT, CALLBACK_PATH);
|
|
236
236
|
}
|
|
237
237
|
|
|
238
|
-
|
|
239
|
-
state: string,
|
|
240
|
-
redirectUri: string,
|
|
241
|
-
): Promise<{ url: string; instructions?: string }> {
|
|
238
|
+
async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
|
|
242
239
|
const pkce = await generatePKCE();
|
|
243
|
-
this
|
|
244
|
-
this
|
|
240
|
+
this.#verifier = pkce.verifier;
|
|
241
|
+
this.#challenge = pkce.challenge;
|
|
245
242
|
|
|
246
243
|
const authParams = new URLSearchParams({
|
|
247
244
|
client_id: CLIENT_ID,
|
|
248
245
|
response_type: "code",
|
|
249
246
|
redirect_uri: redirectUri,
|
|
250
247
|
scope: SCOPES.join(" "),
|
|
251
|
-
code_challenge: this
|
|
248
|
+
code_challenge: this.#challenge,
|
|
252
249
|
code_challenge_method: "S256",
|
|
253
250
|
state,
|
|
254
251
|
access_type: "offline",
|
|
@@ -259,7 +256,7 @@ class GeminiCliOAuthFlow extends OAuthCallbackFlow {
|
|
|
259
256
|
return { url, instructions: "Complete the sign-in in your browser." };
|
|
260
257
|
}
|
|
261
258
|
|
|
262
|
-
|
|
259
|
+
async exchangeToken(code: string, _state: string, redirectUri: string): Promise<OAuthCredentials> {
|
|
263
260
|
this.ctrl.onProgress?.("Exchanging authorization code for tokens...");
|
|
264
261
|
|
|
265
262
|
const tokenResponse = await fetch(TOKEN_URL, {
|
|
@@ -271,7 +268,7 @@ class GeminiCliOAuthFlow extends OAuthCallbackFlow {
|
|
|
271
268
|
code,
|
|
272
269
|
grant_type: "authorization_code",
|
|
273
270
|
redirect_uri: redirectUri,
|
|
274
|
-
code_verifier: this
|
|
271
|
+
code_verifier: this.#verifier,
|
|
275
272
|
}),
|
|
276
273
|
});
|
|
277
274
|
|
|
@@ -53,10 +53,7 @@ class OpenAICodexOAuthFlow extends OAuthCallbackFlow {
|
|
|
53
53
|
super(ctrl, CALLBACK_PORT, CALLBACK_PATH);
|
|
54
54
|
}
|
|
55
55
|
|
|
56
|
-
|
|
57
|
-
state: string,
|
|
58
|
-
redirectUri: string,
|
|
59
|
-
): Promise<{ url: string; instructions?: string }> {
|
|
56
|
+
async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
|
|
60
57
|
const searchParams = new URLSearchParams({
|
|
61
58
|
response_type: "code",
|
|
62
59
|
client_id: CLIENT_ID,
|
|
@@ -74,7 +71,7 @@ class OpenAICodexOAuthFlow extends OAuthCallbackFlow {
|
|
|
74
71
|
return { url, instructions: "A browser window should open. Complete login to finish." };
|
|
75
72
|
}
|
|
76
73
|
|
|
77
|
-
|
|
74
|
+
async exchangeToken(code: string, _state: string, redirectUri: string): Promise<OAuthCredentials> {
|
|
78
75
|
return exchangeCodeForToken(code, this.pkce.verifier, redirectUri);
|
|
79
76
|
}
|
|
80
77
|
}
|
package/src/utils/validation.ts
CHANGED
|
@@ -1,11 +1,7 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
1
|
+
import Ajv from "ajv";
|
|
2
|
+
import addFormats from "ajv-formats";
|
|
3
3
|
import type { Tool, ToolCall } from "../types";
|
|
4
4
|
|
|
5
|
-
// Handle both default and named exports (ESM/CJS interop)
|
|
6
|
-
const Ajv = (AjvModule as any).default || AjvModule;
|
|
7
|
-
const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
|
8
|
-
|
|
9
5
|
// ============================================================================
|
|
10
6
|
// Type Coercion Utilities
|
|
11
7
|
// ============================================================================
|
|
@@ -281,25 +277,13 @@ function coerceArgsFromErrors(
|
|
|
281
277
|
return { value: changed ? nextArgs : args, changed };
|
|
282
278
|
}
|
|
283
279
|
|
|
284
|
-
// Detect if we're in a browser extension environment with strict CSP
|
|
285
|
-
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
|
|
286
|
-
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
|
|
287
|
-
|
|
288
280
|
// Create a singleton AJV instance with formats (only if not in browser extension)
|
|
289
281
|
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
strict: false,
|
|
296
|
-
});
|
|
297
|
-
addFormats(ajv);
|
|
298
|
-
} catch {
|
|
299
|
-
// AJV initialization failed (likely CSP restriction)
|
|
300
|
-
console.warn("AJV validation disabled due to CSP restrictions");
|
|
301
|
-
}
|
|
302
|
-
}
|
|
282
|
+
const ajv = new Ajv({
|
|
283
|
+
allErrors: true,
|
|
284
|
+
strict: false,
|
|
285
|
+
});
|
|
286
|
+
addFormats(ajv);
|
|
303
287
|
|
|
304
288
|
/**
|
|
305
289
|
* Finds a tool by name and validates the tool call arguments against its TypeBox schema
|
|
@@ -326,13 +310,6 @@ export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
|
|
|
326
310
|
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
|
|
327
311
|
const originalArgs = toolCall.arguments;
|
|
328
312
|
|
|
329
|
-
// Skip validation in browser extension environment (CSP restrictions prevent AJV from working)
|
|
330
|
-
if (!ajv || isBrowserExtension) {
|
|
331
|
-
// Trust the LLM's output without validation
|
|
332
|
-
// Browser extensions can't use AJV due to Manifest V3 CSP restrictions
|
|
333
|
-
return originalArgs;
|
|
334
|
-
}
|
|
335
|
-
|
|
336
313
|
// Compile the schema
|
|
337
314
|
const validate = ajv.compile(tool.parameters);
|
|
338
315
|
|