@oh-my-pi/pi-coding-agent 11.8.3 → 11.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +42 -0
- package/package.json +7 -7
- package/src/capability/mcp.ts +9 -0
- package/src/config/file-lock.ts +1 -1
- package/src/discovery/builtin.ts +48 -0
- package/src/discovery/mcp-json.ts +33 -0
- package/src/extensibility/slash-commands.ts +1 -0
- package/src/index.ts +0 -2
- package/src/mcp/config-writer.ts +194 -0
- package/src/mcp/config.ts +20 -6
- package/src/mcp/index.ts +4 -0
- package/src/mcp/loader.ts +6 -0
- package/src/mcp/manager.ts +92 -3
- package/src/mcp/oauth-discovery.ts +274 -0
- package/src/mcp/oauth-flow.ts +229 -0
- package/src/mcp/tool-bridge.ts +8 -8
- package/src/mcp/transports/http.ts +76 -35
- package/src/mcp/transports/stdio.ts +31 -16
- package/src/mcp/types.ts +15 -1
- package/src/modes/components/mcp-add-wizard.ts +1286 -0
- package/src/modes/components/tool-execution.ts +12 -24
- package/src/modes/controllers/input-controller.ts +8 -0
- package/src/modes/controllers/mcp-command-controller.ts +1223 -0
- package/src/modes/interactive-mode.ts +6 -0
- package/src/modes/types.ts +1 -0
- package/src/sdk.ts +1 -0
- package/src/session/agent-session.ts +49 -0
- package/src/system-prompt.ts +2 -3
- package/src/task/executor.ts +26 -38
- package/src/task/worktree.ts +8 -5
- package/src/tools/bash.ts +8 -4
- package/src/tools/browser.ts +7 -4
- package/src/tools/grep.ts +1 -13
- package/src/tools/index.ts +1 -1
- package/src/utils/event-bus.ts +3 -1
package/src/mcp/manager.ts
CHANGED
|
@@ -6,7 +6,10 @@
|
|
|
6
6
|
*/
|
|
7
7
|
import { logger } from "@oh-my-pi/pi-utils";
|
|
8
8
|
import type { TSchema } from "@sinclair/typebox";
|
|
9
|
+
import type { SourceMeta } from "../capability/types";
|
|
10
|
+
import { resolveConfigValue } from "../config/resolve-config-value";
|
|
9
11
|
import type { CustomTool } from "../extensibility/custom-tools/types";
|
|
12
|
+
import type { AuthStorage } from "../session/auth-storage";
|
|
10
13
|
import { connectToServer, disconnectServer, listTools } from "./client";
|
|
11
14
|
import { loadAllMCPConfigs, validateServerConfig } from "./config";
|
|
12
15
|
import type { MCPToolDetails } from "./tool-bridge";
|
|
@@ -14,8 +17,6 @@ import { DeferredMCPTool, MCPTool } from "./tool-bridge";
|
|
|
14
17
|
import type { MCPToolCache } from "./tool-cache";
|
|
15
18
|
import type { MCPServerConfig, MCPServerConnection, MCPToolDefinition } from "./types";
|
|
16
19
|
|
|
17
|
-
type SourceMeta = import("../capability/types").SourceMeta;
|
|
18
|
-
|
|
19
20
|
type ToolLoadResult = {
|
|
20
21
|
connection: MCPServerConnection;
|
|
21
22
|
serverTools: MCPToolDefinition[];
|
|
@@ -82,12 +83,20 @@ export class MCPManager {
|
|
|
82
83
|
#pendingConnections = new Map<string, Promise<MCPServerConnection>>();
|
|
83
84
|
#pendingToolLoads = new Map<string, Promise<ToolLoadResult>>();
|
|
84
85
|
#sources = new Map<string, SourceMeta>();
|
|
86
|
+
#authStorage: AuthStorage | null = null;
|
|
85
87
|
|
|
86
88
|
/**
 * @param cwd Working directory this manager operates in (presumably the root for `.mcp.json` discovery — confirm).
 * @param toolCache Optional cache of MCP tool definitions, or `null` when none is used.
 */
constructor(private cwd: string, private toolCache: MCPToolCache | null = null) {}
|
|
90
92
|
|
|
93
|
+
/**
 * Attach the credential store consulted for servers whose config declares
 * `auth: { type: "oauth", credentialId }`. Until a store is attached, no
 * OAuth credentials are injected into server configs.
 *
 * @param authStorage Store used to look up credentials by id.
 */
setAuthStorage(authStorage: AuthStorage): void {
	this.#authStorage = authStorage;
}
|
|
99
|
+
|
|
91
100
|
/**
|
|
92
101
|
* Discover and connect to all MCP servers from .mcp.json files.
|
|
93
102
|
* Returns tools and any connection errors.
|
|
@@ -154,8 +163,14 @@ export class MCPManager {
|
|
|
154
163
|
continue;
|
|
155
164
|
}
|
|
156
165
|
|
|
157
|
-
|
|
166
|
+
// Resolve auth config before connecting
|
|
167
|
+
const resolvedConfig = await this.#resolveAuthConfig(config);
|
|
168
|
+
|
|
169
|
+
const connectionPromise = connectToServer(name, resolvedConfig).then(
|
|
158
170
|
connection => {
|
|
171
|
+
// Store original config (without resolved tokens) to keep
|
|
172
|
+
// cache keys stable and avoid leaking rotating credentials.
|
|
173
|
+
connection.config = config;
|
|
159
174
|
if (sources[name]) {
|
|
160
175
|
connection._source = sources[name];
|
|
161
176
|
}
|
|
@@ -286,6 +301,15 @@ export class MCPManager {
|
|
|
286
301
|
return this.#connections.get(name);
|
|
287
302
|
}
|
|
288
303
|
|
|
304
|
+
/**
 * Report where a server currently is in its connection lifecycle.
 *
 * - "connected": an established connection is stored under `name`.
 * - "connecting": a connection attempt or tool load for `name` is still pending.
 * - "disconnected": neither of the above.
 */
getConnectionStatus(name: string): "connected" | "connecting" | "disconnected" {
	if (this.#connections.has(name)) {
		return "connected";
	}
	const inFlight = this.#pendingConnections.has(name) || this.#pendingToolLoads.has(name);
	return inFlight ? "connecting" : "disconnected";
}
|
|
312
|
+
|
|
289
313
|
/**
|
|
290
314
|
* Get the source metadata for a server.
|
|
291
315
|
*/
|
|
@@ -304,6 +328,13 @@ export class MCPManager {
|
|
|
304
328
|
throw new Error(`MCP server not connected: ${name}`);
|
|
305
329
|
}
|
|
306
330
|
|
|
331
|
+
/**
 * Produce a connection-ready copy of `config`: OAuth credentials are injected
 * (when an auth storage is attached) and header/env values go through
 * shell-command substitution. Public entry point for the same preparation
 * the manager performs before connecting a server.
 *
 * @param config Server config as written by the user.
 * @returns A resolved copy; `config` itself is left untouched.
 */
prepareConfig(config: MCPServerConfig): Promise<MCPServerConfig> {
	return this.#resolveAuthConfig(config);
}
|
|
337
|
+
|
|
307
338
|
/**
|
|
308
339
|
* Get all connected server names.
|
|
309
340
|
*/
|
|
@@ -369,6 +400,64 @@ export class MCPManager {
|
|
|
369
400
|
const promises = Array.from(this.#connections.keys()).map(name => this.refreshServerTools(name));
|
|
370
401
|
await Promise.allSettled(promises);
|
|
371
402
|
}
|
|
403
|
+
|
|
404
|
+
/**
 * Produce a connection-ready copy of `config` (the input is not mutated).
 *
 * 1. User-authored header values (http/sse) or env values (other transports)
 *    are passed through `resolveConfigValue` (shell-command substitution).
 *    Entries that resolve to an empty value are dropped.
 * 2. If `auth.type === "oauth"` and the referenced stored credential is a
 *    non-empty OAuth token, it is injected as `Authorization: Bearer <token>`
 *    (http/sse) or as the `OAUTH_ACCESS_TOKEN` env var. It overrides any
 *    configured value for that key, and that configured value is never
 *    resolved.
 *
 * The token is injected only after substitution. A value issued by a remote
 * authorization server must never be interpreted by `resolveConfigValue`; for
 * example, a token beginning with "!" must not be run as a shell command.
 */
async #resolveAuthConfig(config: MCPServerConfig): Promise<MCPServerConfig> {
	const accessToken = this.#lookupOAuthAccessToken(config);

	if (config.type === "http" || config.type === "sse") {
		const headers = await this.#resolveConfigMap(config.headers, accessToken ? "Authorization" : undefined);
		if (accessToken) {
			return { ...config, headers: { ...headers, Authorization: `Bearer ${accessToken}` } };
		}
		return headers ? { ...config, headers } : { ...config };
	}

	const env = await this.#resolveConfigMap(config.env, accessToken ? "OAUTH_ACCESS_TOKEN" : undefined);
	if (accessToken) {
		return { ...config, env: { ...env, OAUTH_ACCESS_TOKEN: accessToken } };
	}
	return env ? { ...config, env } : { ...config };
}

/**
 * Look up the OAuth access token referenced by `config.auth.credentialId`.
 *
 * Returns undefined in any of these cases:
 * - no auth storage is attached;
 * - the config does not use OAuth;
 * - the credential is missing, not an OAuth credential, or has an empty token;
 * - the lookup throws (logged as a warning).
 */
#lookupOAuthAccessToken(config: MCPServerConfig): string | undefined {
	const auth = config.auth;
	const storage = this.#authStorage;
	if (auth?.type === "oauth" && auth.credentialId && storage) {
		const credentialId = auth.credentialId;
		try {
			const credential = storage.get(credentialId);
			if (credential?.type === "oauth" && credential.access) {
				return credential.access;
			}
		} catch (error) {
			logger.warn("Failed to resolve OAuth credential", { credentialId, error });
		}
	}
	return undefined;
}

/**
 * Pass each value of `values` through `resolveConfigValue`, in order.
 *
 * Entries that resolve to an empty value are dropped. `skipKey`, if given, is
 * omitted without being resolved because the caller overwrites it.
 *
 * @returns The resolved map, or undefined when `values` is undefined.
 */
async #resolveConfigMap(
	values: Record<string, string> | undefined,
	skipKey?: string,
): Promise<Record<string, string> | undefined> {
	if (!values) return undefined;
	const resolvedMap: Record<string, string> = {};
	for (const [key, value] of Object.entries(values)) {
		if (key === skipKey) continue;
		const resolvedValue = await resolveConfigValue(value);
		if (resolvedValue) resolvedMap[key] = resolvedValue;
	}
	return resolvedMap;
}
|
|
372
461
|
}
|
|
373
462
|
|
|
374
463
|
/**
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* MCP OAuth Auto-Discovery
|
|
3
|
+
*
|
|
4
|
+
* Automatically detects OAuth requirements from MCP server responses
|
|
5
|
+
* and extracts authentication endpoints.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
/** OAuth 2.0 endpoints, plus optional client hints, discovered for an MCP server. */
export interface OAuthEndpoints {
	/** Authorization endpoint the user is sent to. */
	authorizationUrl: string;
	/** Token endpoint used to exchange the authorization code. */
	tokenUrl: string;
	/** Client identifier, when the server advertises one. */
	clientId?: string;
	/** Requested scopes as a single space-separated string. */
	scopes?: string;
}

/** Classification of an error's authentication requirements. */
export interface AuthDetectionResult {
	/** Whether the error indicates that authentication is needed. */
	requiresAuth: boolean;
	/** Best guess at the required scheme, when authentication is needed. */
	authType?: "oauth" | "apikey" | "unknown";
	/** Discovered endpoints, when the scheme is OAuth. */
	oauth?: OAuthEndpoints;
	/** Human-readable summary suitable for display. */
	message?: string;
}
|
|
21
|
+
|
|
22
|
+
/** HTTP 401/403 as standalone numbers, so digit runs like "14013" or "84012" do not match. */
const AUTH_STATUS_CODE_PATTERN = /(?<!\d)40[13](?!\d)/;

/** Lower-case phrases indicating an authentication/authorization failure. */
const AUTH_ERROR_PHRASES = [
	"unauthorized",
	"forbidden",
	"authentication required",
	"authentication failed",
] as const;

/**
 * Detect whether an error indicates that authentication is required.
 *
 * Matching is case-insensitive. An error qualifies if its message contains:
 * - an HTTP 401 or 403 status as a standalone number; a port such as "14013"
 *   or an id such as "84012" does not count;
 * - or one of the phrases "unauthorized", "forbidden",
 *   "authentication required" or "authentication failed".
 *
 * @param error Error whose `message` is inspected.
 * @returns true when the message looks like an authentication/authorization failure.
 */
export function detectAuthError(error: Error): boolean {
	const errorMsg = error.message.toLowerCase();
	return AUTH_STATUS_CODE_PATTERN.test(errorMsg) || AUTH_ERROR_PHRASES.some(phrase => errorMsg.includes(phrase));
}
|
|
43
|
+
|
|
44
|
+
/**
 * Extract OAuth endpoints from an error message.
 *
 * Sources are tried in order:
 * 1. A JSON object embedded in the message. Endpoints are read from an
 *    `oauth`, `authorization` or `auth` object, otherwise from the top level.
 *    snake_case and camelCase `*_url` / `*_endpoint` / `*_uri` keys are accepted.
 * 2. `key="value"` pairs, as in a `WWW-Authenticate` challenge. Keys are
 *    compared case-insensitively; `realm` is used as the authorization URL only
 *    when no explicit authorization key is present.
 * 3. A literal `realm="…" … token_url="…"` sequence.
 *
 * Only non-empty strings are accepted as URLs and client ids. Scopes may be a
 * space-separated string or an array of strings, which are joined with spaces.
 * A missing client id or scope is taken from the authorization URL's
 * `client_id` / `scope` query parameter when present.
 *
 * @param error Error whose `message` is searched.
 * @returns The endpoints, or null when no authorization/token URL pair is found.
 */
export function extractOAuthEndpoints(error: Error): OAuthEndpoints | null {
	const errorMsg = error.message;

	const isRecord = (value: unknown): value is Record<string, unknown> =>
		typeof value === "object" && value !== null && !Array.isArray(value);

	// Non-empty string, or undefined for any other value (numbers, objects, "").
	const asString = (value: unknown): string | undefined =>
		typeof value === "string" && value !== "" ? value : undefined;

	// Scopes may arrive as "a b" or ["a", "b"]; normalize to a space-separated string.
	const asScopes = (value: unknown): string | undefined =>
		Array.isArray(value)
			? asString(value.filter((item): item is string => typeof item === "string").join(" "))
			: asString(value);

	const firstString = (obj: Record<string, unknown>, keys: readonly string[]): string | undefined => {
		for (const key of keys) {
			const value = asString(obj[key]);
			if (value) return value;
		}
		return undefined;
	};

	const queryParam = (url: string, name: string): string | undefined => {
		try {
			return new URL(url).searchParams.get(name) ?? undefined;
		} catch {
			return undefined;
		}
	};

	// Fill a missing client id / scope from the authorization URL's query string.
	const withUrlDefaults = (endpoints: OAuthEndpoints): OAuthEndpoints => ({
		...endpoints,
		clientId: endpoints.clientId || queryParam(endpoints.authorizationUrl, "client_id"),
		scopes: endpoints.scopes || queryParam(endpoints.authorizationUrl, "scope"),
	});

	const readEndpointsFromObject = (obj: Record<string, unknown>): OAuthEndpoints | null => {
		const authorizationUrl = firstString(obj, [
			"authorization_url",
			"authorizationUrl",
			"authorization_endpoint",
			"authorizationEndpoint",
			"authorization_uri",
			"authorizationUri",
		]);
		const tokenUrl = firstString(obj, ["token_url", "tokenUrl", "token_endpoint", "tokenEndpoint", "token_uri", "tokenUri"]);
		if (!authorizationUrl || !tokenUrl) return null;
		return {
			authorizationUrl,
			tokenUrl,
			clientId: firstString(obj, ["client_id", "clientId", "default_client_id", "public_client_id"]),
			scopes: asScopes(obj.scopes) || asScopes(obj.scope) || asScopes(obj.scopes_supported),
		};
	};

	// 1. JSON error body embedded in the message (first "{" to last "}").
	const jsonMatch = errorMsg.match(/\{[\s\S]*\}/);
	if (jsonMatch) {
		let errorBody: unknown = undefined;
		try {
			errorBody = JSON.parse(jsonMatch[0]);
		} catch {
			// Not valid JSON; continue with challenge parsing.
		}
		if (isRecord(errorBody)) {
			const nested = errorBody.oauth || errorBody.authorization || errorBody.auth;
			const endpoints =
				(isRecord(nested) ? readEndpointsFromObject(nested) : null) ?? readEndpointsFromObject(errorBody);
			if (endpoints) return withUrlDefaults(endpoints);
		}
	}

	// 2. key="value" pairs, e.g. a WWW-Authenticate challenge. Later duplicates win.
	const challengeValues = new Map<string, string>();
	for (const [, rawKey, value] of Array.from(errorMsg.matchAll(/([a-zA-Z_][a-zA-Z0-9_-]*)="([^"]+)"/g))) {
		if (rawKey && value) challengeValues.set(rawKey.toLowerCase(), value);
	}
	if (challengeValues.size > 0) {
		const pick = (...keys: string[]): string | undefined => {
			for (const key of keys) {
				const value = challengeValues.get(key);
				if (value) return value;
			}
			return undefined;
		};
		const authorizationUrl = pick(
			"authorization_uri",
			"authorization_url",
			"authorization_endpoint",
			"authorize_url",
			"realm",
		);
		const tokenUrl = pick("token_url", "token_uri", "token_endpoint");
		if (authorizationUrl && tokenUrl) {
			return withUrlDefaults({
				authorizationUrl,
				tokenUrl,
				clientId: challengeValues.get("client_id"),
				scopes: pick("scope", "scopes"),
			});
		}
	}

	// 3. Literal realm/token_url pair, e.g.
	// Bearer realm="https://auth.example.com/oauth/authorize" token_url="https://auth.example.com/oauth/token"
	const wwwAuthMatch = errorMsg.match(/realm="([^"]+)".*token_url="([^"]+)"/);
	if (wwwAuthMatch) {
		const [, authorizationUrl, tokenUrl] = wwwAuthMatch;
		if (authorizationUrl && tokenUrl) {
			return withUrlDefaults({ authorizationUrl, tokenUrl });
		}
	}

	return null;
}
|
|
171
|
+
|
|
172
|
+
/**
 * Classify an error's authentication requirements.
 *
 * - Not an auth error according to `detectAuthError`: `{ requiresAuth: false }`.
 * - OAuth endpoints extractable from the message: `authType: "oauth"` with the endpoints.
 * - Message mentions an API key, a token or bearer auth: `authType: "apikey"`.
 * - Otherwise: `authType: "unknown"`.
 *
 * @param error Error to classify.
 * @returns Structured description of the authentication that appears to be required.
 */
export function analyzeAuthError(error: Error): AuthDetectionResult {
	if (!detectAuthError(error)) return { requiresAuth: false };

	const oauth = extractOAuthEndpoints(error);
	if (oauth) {
		return {
			requiresAuth: true,
			authType: "oauth",
			oauth,
			message: "Server requires OAuth authentication. Launching authorization flow...",
		};
	}

	const lowered = error.message.toLowerCase();
	const apiKeyHints = ["api key", "api_key", "token", "bearer"];
	const looksLikeApiKey = apiKeyHints.some(hint => lowered.includes(hint));

	return looksLikeApiKey
		? { requiresAuth: true, authType: "apikey", message: "Server requires API key authentication." }
		: {
				requiresAuth: true,
				authType: "unknown",
				message: "Server requires authentication but type could not be determined.",
			};
}
|
|
215
|
+
|
|
216
|
+
/**
 * Try to discover OAuth endpoints by probing well-known metadata paths.
 * This is a fallback for when error responses carry no OAuth metadata.
 *
 * Paths are absolute, so they resolve against the origin of `serverUrl`. They
 * are tried in order, and the first JSON response that yields both an
 * authorization URL and a token URL wins. Two shapes are recognised:
 * - RFC 8414 / OpenID discovery documents: `authorization_endpoint` and
 *   `token_endpoint`.
 * - An MCP-style `oauth` / `authorization` / `auth` object with
 *   `authorization_url` or `authorizationUrl`, and `token_url` or `tokenUrl`.
 *
 * Network errors, timeouts, non-2xx responses and non-JSON bodies are ignored.
 * Scopes may be a string or an array of strings.
 *
 * @param serverUrl Base URL of the MCP server.
 * @param timeoutMs Per-request timeout in milliseconds (default 5000).
 * @returns The discovered endpoints, or null.
 */
export async function discoverOAuthEndpoints(serverUrl: string, timeoutMs = 5000): Promise<OAuthEndpoints | null> {
	const wellKnownPaths = [
		"/.well-known/oauth-authorization-server",
		"/.well-known/openid-configuration",
		"/oauth/metadata",
		"/.mcp/auth",
		"/authorize", // Some MCP servers expose OAuth config here
	];

	const isRecord = (value: unknown): value is Record<string, unknown> =>
		typeof value === "object" && value !== null && !Array.isArray(value);

	const asString = (value: unknown): string | undefined =>
		typeof value === "string" && value !== "" ? value : undefined;

	const asScopes = (value: unknown): string | undefined =>
		Array.isArray(value)
			? asString(value.filter((item): item is string => typeof item === "string").join(" "))
			: asString(value);

	const firstString = (obj: Record<string, unknown>, keys: readonly string[]): string | undefined => {
		for (const key of keys) {
			const value = asString(obj[key]);
			if (value) return value;
		}
		return undefined;
	};

	const clientIdKeys = ["client_id", "clientId", "default_client_id", "public_client_id"] as const;

	for (const path of wellKnownPaths) {
		let metadata: unknown = undefined;
		try {
			const url = new URL(path, serverUrl);
			const response = await fetch(url.toString(), {
				method: "GET",
				headers: { Accept: "application/json" },
				// Without a bound, one unresponsive host would stall discovery indefinitely.
				signal: AbortSignal.timeout(timeoutMs),
			});
			if (!response.ok) continue;
			metadata = await response.json();
		} catch {
			// Network error, timeout, invalid URL or non-JSON body: try the next path.
			continue;
		}
		if (!isRecord(metadata)) continue;

		// Standard OAuth / OpenID discovery format.
		const authorizationEndpoint = asString(metadata.authorization_endpoint);
		const tokenEndpoint = asString(metadata.token_endpoint);
		if (authorizationEndpoint && tokenEndpoint) {
			return {
				authorizationUrl: authorizationEndpoint,
				tokenUrl: tokenEndpoint,
				clientId: firstString(metadata, clientIdKeys),
				scopes: asScopes(metadata.scopes_supported) || asScopes(metadata.scopes) || asScopes(metadata.scope),
			};
		}

		// MCP-specific format: endpoints nested under oauth / authorization / auth.
		const oauthData = metadata.oauth || metadata.authorization || metadata.auth;
		if (isRecord(oauthData)) {
			const authorizationUrl = firstString(oauthData, ["authorization_url", "authorizationUrl"]);
			const tokenUrl = firstString(oauthData, ["token_url", "tokenUrl"]);
			if (authorizationUrl && tokenUrl) {
				return {
					authorizationUrl,
					tokenUrl,
					clientId: firstString(oauthData, clientIdKeys),
					scopes: asScopes(oauthData.scopes) || asScopes(oauthData.scope),
				};
			}
		}
	}

	return null;
}
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Generic OAuth flow for MCP servers.
|
|
3
|
+
*
|
|
4
|
+
* Allows users to authenticate with any OAuth-compatible MCP server
|
|
5
|
+
* by providing authorization URL, token URL, and client credentials.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import type { OAuthController, OAuthCredentials } from "@oh-my-pi/pi-ai";
|
|
9
|
+
import { OAuthCallbackFlow } from "@oh-my-pi/pi-ai/utils/oauth/callback-server";
|
|
10
|
+
|
|
11
|
+
/** Port handed to OAuthCallbackFlow for the local redirect listener. */
const DEFAULT_PORT = 3000;
/** Path handed to OAuthCallbackFlow for the redirect callback. */
const CALLBACK_PATH = "/callback";

/** Endpoints and client settings for running the OAuth flow against an MCP server. */
export interface MCPOAuthConfig {
	/** Authorization endpoint; may already carry query parameters such as `client_id` or `scope`. */
	authorizationUrl: string;
	/** Token endpoint used for the authorization-code exchange. */
	tokenUrl: string;
	/** Client ID; may be omitted when the authorization URL already embeds one. */
	clientId?: string;
	/** Client secret; may be omitted for public (PKCE-only) clients. */
	clientSecret?: string;
	/** Space-separated OAuth scopes to request. */
	scopes?: string;
}
|
|
26
|
+
|
|
27
|
+
/**
 * OAuth 2.0 authorization-code flow with PKCE (S256) for MCP servers.
 *
 * The base OAuthCallbackFlow is constructed with DEFAULT_PORT and
 * CALLBACK_PATH for the redirect callback.
 *
 * The client id is resolved in this order:
 * 1. `config.clientId`;
 * 2. the `client_id` query parameter already present in
 *    `config.authorizationUrl`;
 * 3. dynamic client registration, using the `registration_endpoint`
 *    advertised at `/.well-known/oauth-authorization-server` on the
 *    authorization URL's origin.
 */
export class MCPOAuthFlow extends OAuthCallbackFlow {
	/** Client id from config, the authorization URL, or dynamic registration. */
	#resolvedClientId?: string;
	/** Client secret returned by dynamic registration, if any. */
	#registeredClientSecret?: string;
	/** PKCE verifier of the in-flight authorization request; cleared by exchangeToken. */
	#codeVerifier?: string;

	constructor(
		private config: MCPOAuthConfig,
		ctrl: OAuthController,
	) {
		super(ctrl, DEFAULT_PORT, CALLBACK_PATH);
		this.#resolvedClientId = this.#resolveClientId(config);
	}

	/**
	 * Build the authorization URL for the user to open.
	 *
	 * Starts from `config.authorizationUrl`. Any `response_type`, `client_id`
	 * and `scope` already present are kept. `redirect_uri`, `state` and a fresh
	 * S256 PKCE challenge are always set. When no client id is known yet,
	 * dynamic client registration is attempted first.
	 */
	async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
		if (!this.#resolvedClientId) {
			await this.#tryRegisterClient(redirectUri);
		}

		const authUrl = new URL(this.config.authorizationUrl);
		const params = authUrl.searchParams;

		if (!params.get("response_type")) {
			params.set("response_type", "code");
		}
		if (this.#resolvedClientId && !params.get("client_id")) {
			params.set("client_id", this.#resolvedClientId);
		}
		if (this.config.scopes && !params.get("scope")) {
			params.set("scope", this.config.scopes);
		}
		params.set("redirect_uri", redirectUri);
		params.set("state", state);

		// PKCE: the verifier stays local; only its SHA-256 challenge is sent.
		const codeVerifier = this.#generateCodeVerifier();
		const codeChallenge = await this.#generateCodeChallenge(codeVerifier);
		params.set("code_challenge", codeChallenge);
		params.set("code_challenge_method", "S256");

		// Kept for the token exchange.
		this.#codeVerifier = codeVerifier;

		return { url: authUrl.toString() };
	}

	/**
	 * Exchange an authorization code for tokens at `config.tokenUrl`.
	 *
	 * Sends `client_id`, the PKCE `code_verifier` (single-use), and a client
	 * secret (configured or from registration) when available. JSON is
	 * requested, but a form-encoded reply is also accepted.
	 *
	 * @throws Error when the endpoint returns a non-2xx status, or when the
	 *   response has no `access_token`. The provider's `error` and
	 *   `error_description` are included in the message when present.
	 * @returns Credentials. `refresh` is "" when no refresh token is issued.
	 *   `expires` is one hour from now when `expires_in` is absent or not numeric.
	 */
	async exchangeToken(code: string, _state: string, redirectUri: string): Promise<OAuthCredentials> {
		const params = new URLSearchParams({
			grant_type: "authorization_code",
			code,
			redirect_uri: redirectUri,
		});
		if (this.#resolvedClientId) {
			params.set("client_id", this.#resolvedClientId);
		}

		// The verifier is single-use: send it once, then forget it.
		if (this.#codeVerifier) {
			params.set("code_verifier", this.#codeVerifier);
		}
		this.#codeVerifier = undefined;

		const clientSecret = this.config.clientSecret ?? this.#registeredClientSecret;
		if (clientSecret) {
			params.set("client_secret", clientSecret);
		}

		const response = await fetch(this.config.tokenUrl, {
			method: "POST",
			headers: {
				"Content-Type": "application/x-www-form-urlencoded",
				// Some providers (e.g. GitHub) answer form-encoded unless JSON is requested.
				Accept: "application/json",
			},
			body: params.toString(),
		});

		const bodyText = await response.text();
		if (!response.ok) {
			throw new Error(`Token exchange failed: ${response.status} ${bodyText}`);
		}

		const data = this.#parseTokenResponse(bodyText);
		const accessToken = typeof data.access_token === "string" ? data.access_token : "";
		if (!accessToken) {
			// Some providers report errors with a 2xx status and an `error` field.
			const reason = [data.error, data.error_description]
				.filter((part): part is string => typeof part === "string" && part !== "")
				.join(": ");
			throw new Error(`Token exchange failed: ${reason || "response did not include an access_token"}`);
		}

		// expires_in may be missing or sent as a string; default to one hour.
		const expiresIn = Number(data.expires_in ?? 3600);
		const expires = Date.now() + (Number.isFinite(expiresIn) ? expiresIn : 3600) * 1000;

		return {
			access: accessToken,
			refresh: typeof data.refresh_token === "string" ? data.refresh_token : "",
			expires,
		};
	}

	/**
	 * Parse a token-endpoint body as a JSON object. Falls back to
	 * `application/x-www-form-urlencoded` decoding when the body is not a JSON object.
	 */
	#parseTokenResponse(body: string): Record<string, unknown> {
		try {
			const parsed: unknown = JSON.parse(body);
			if (typeof parsed === "object" && parsed !== null && !Array.isArray(parsed)) {
				return parsed as Record<string, unknown>;
			}
		} catch {
			// Not JSON; try form encoding below.
		}
		const form: Record<string, string> = {};
		new URLSearchParams(body).forEach((value, key) => {
			form[key] = value;
		});
		return form;
	}

	/**
	 * Generate a PKCE code verifier: 32 random bytes, base64url-encoded.
	 */
	#generateCodeVerifier(): string {
		const bytes = new Uint8Array(32);
		crypto.getRandomValues(bytes);
		return this.#base64UrlEncode(bytes);
	}

	/**
	 * Derive the S256 PKCE code challenge: base64url(SHA-256(verifier)).
	 */
	async #generateCodeChallenge(verifier: string): Promise<string> {
		const encoder = new TextEncoder();
		const data = encoder.encode(verifier);
		const hash = await crypto.subtle.digest("SHA-256", data);
		return this.#base64UrlEncode(new Uint8Array(hash));
	}

	/**
	 * Base64url-encode bytes without padding.
	 */
	#base64UrlEncode(bytes: Uint8Array): string {
		const base64 = btoa(String.fromCharCode(...bytes));
		return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
	}

	/**
	 * Return the trimmed `config.clientId`, or else the `client_id` query
	 * parameter of the authorization URL. Returns undefined if neither exists
	 * or the URL is invalid.
	 */
	#resolveClientId(config: MCPOAuthConfig): string | undefined {
		const fromConfig = config.clientId?.trim();
		if (fromConfig) return fromConfig;

		try {
			return new URL(config.authorizationUrl).searchParams.get("client_id") ?? undefined;
		} catch {
			return undefined;
		}
	}

	/**
	 * Try OAuth dynamic client registration when the provider requires a
	 * client_id and none is known. Registers a public native client. Any
	 * failure is ignored, leaving the client id unresolved.
	 */
	async #tryRegisterClient(redirectUri: string): Promise<void> {
		const registrationEndpoint = await this.#resolveRegistrationEndpoint();
		if (!registrationEndpoint) return;

		try {
			const response = await fetch(registrationEndpoint, {
				method: "POST",
				headers: {
					"Content-Type": "application/json",
					Accept: "application/json",
				},
				body: JSON.stringify({
					client_name: "oh-my-pi MCP",
					redirect_uris: [redirectUri],
					grant_types: ["authorization_code", "refresh_token"],
					response_types: ["code"],
					token_endpoint_auth_method: "none",
					application_type: "native",
				}),
			});

			if (!response.ok) return;

			const data = (await response.json()) as {
				client_id?: string;
				client_secret?: string;
			};

			if (data.client_id && data.client_id.trim() !== "") {
				this.#resolvedClientId = data.client_id;
			}
			if (data.client_secret && data.client_secret.trim() !== "") {
				this.#registeredClientSecret = data.client_secret;
			}
		} catch {
			// Ignore registration failures and continue without client registration.
		}
	}

	/**
	 * Look up `registration_endpoint` in the authorization server metadata
	 * served at `/.well-known/oauth-authorization-server` on the authorization
	 * URL's origin. Returns null when unavailable.
	 */
	async #resolveRegistrationEndpoint(): Promise<string | null> {
		try {
			const authorizationEndpoint = new URL(this.config.authorizationUrl);
			const metadataUrl = new URL("/.well-known/oauth-authorization-server", authorizationEndpoint.origin);
			const response = await fetch(metadataUrl.toString(), {
				method: "GET",
				headers: { Accept: "application/json" },
			});

			if (!response.ok) return null;
			const metadata = (await response.json()) as { registration_endpoint?: string };
			if (metadata.registration_endpoint && metadata.registration_endpoint.trim() !== "") {
				return metadata.registration_endpoint;
			}
		} catch {
			// Ignore metadata discovery failures.
		}

		return null;
	}
}
|
package/src/mcp/tool-bridge.ts
CHANGED
|
@@ -219,8 +219,8 @@ export class DeferredMCPTool implements CustomTool<TSchema, MCPToolDetails> {
|
|
|
219
219
|
readonly mcpToolName: string;
|
|
220
220
|
/** Server name */
|
|
221
221
|
readonly mcpServerName: string;
|
|
222
|
-
|
|
223
|
-
|
|
222
|
+
readonly #fallbackProvider: string | undefined;
|
|
223
|
+
readonly #fallbackProviderName: string | undefined;
|
|
224
224
|
|
|
225
225
|
/** Create DeferredMCPTool instances for all tools from an MCP server */
|
|
226
226
|
static fromTools(
|
|
@@ -244,8 +244,8 @@ export class DeferredMCPTool implements CustomTool<TSchema, MCPToolDetails> {
|
|
|
244
244
|
this.parameters = convertSchema(tool.inputSchema);
|
|
245
245
|
this.mcpToolName = tool.name;
|
|
246
246
|
this.mcpServerName = serverName;
|
|
247
|
-
this
|
|
248
|
-
this
|
|
247
|
+
this.#fallbackProvider = source?.provider;
|
|
248
|
+
this.#fallbackProviderName = source?.providerName;
|
|
249
249
|
}
|
|
250
250
|
|
|
251
251
|
renderCall(args: unknown, theme: Theme) {
|
|
@@ -273,8 +273,8 @@ export class DeferredMCPTool implements CustomTool<TSchema, MCPToolDetails> {
|
|
|
273
273
|
mcpToolName: this.tool.name,
|
|
274
274
|
isError: result.isError,
|
|
275
275
|
rawContent: result.content,
|
|
276
|
-
provider: connection._source?.provider ?? this
|
|
277
|
-
providerName: connection._source?.providerName ?? this
|
|
276
|
+
provider: connection._source?.provider ?? this.#fallbackProvider,
|
|
277
|
+
providerName: connection._source?.providerName ?? this.#fallbackProviderName,
|
|
278
278
|
};
|
|
279
279
|
|
|
280
280
|
if (result.isError) {
|
|
@@ -296,8 +296,8 @@ export class DeferredMCPTool implements CustomTool<TSchema, MCPToolDetails> {
|
|
|
296
296
|
serverName: this.serverName,
|
|
297
297
|
mcpToolName: this.tool.name,
|
|
298
298
|
isError: true,
|
|
299
|
-
provider: this
|
|
300
|
-
providerName: this
|
|
299
|
+
provider: this.#fallbackProvider,
|
|
300
|
+
providerName: this.#fallbackProviderName,
|
|
301
301
|
},
|
|
302
302
|
};
|
|
303
303
|
}
|