@oh-my-pi/pi-coding-agent 3.34.0 → 3.36.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +43 -1
- package/README.md +7 -2
- package/package.json +5 -5
- package/src/core/agent-session.ts +202 -33
- package/src/core/auth-storage.ts +293 -28
- package/src/core/model-registry.ts +7 -8
- package/src/core/sdk.ts +12 -6
- package/src/core/session-manager.ts +15 -0
- package/src/core/settings-manager.ts +20 -6
- package/src/core/system-prompt.ts +1 -0
- package/src/core/title-generator.ts +3 -1
- package/src/core/tools/bash.ts +11 -3
- package/src/core/tools/calculator.ts +500 -0
- package/src/core/tools/edit.ts +1 -0
- package/src/core/tools/grep.ts +1 -1
- package/src/core/tools/index.test.ts +2 -0
- package/src/core/tools/index.ts +5 -0
- package/src/core/tools/renderers.ts +3 -0
- package/src/core/tools/task/index.ts +10 -1
- package/src/core/tools/task/model-resolver.ts +5 -4
- package/src/core/tools/web-search/auth.ts +13 -5
- package/src/core/tools/web-search/types.ts +11 -7
- package/src/main.ts +3 -0
- package/src/modes/interactive/components/oauth-selector.ts +1 -2
- package/src/modes/interactive/components/tool-execution.ts +15 -12
- package/src/modes/interactive/interactive-mode.ts +49 -13
- package/src/prompts/tools/ask.md +11 -5
- package/src/prompts/tools/bash.md +1 -0
- package/src/prompts/tools/calculator.md +8 -0
package/src/core/auth-storage.ts
CHANGED
|
@@ -3,7 +3,17 @@
|
|
|
3
3
|
* Handles loading, saving, and refreshing credentials from auth.json.
|
|
4
4
|
*/
|
|
5
5
|
|
|
6
|
-
import {
|
|
6
|
+
import {
|
|
7
|
+
chmodSync,
|
|
8
|
+
closeSync,
|
|
9
|
+
existsSync,
|
|
10
|
+
openSync,
|
|
11
|
+
readFileSync,
|
|
12
|
+
renameSync,
|
|
13
|
+
statSync,
|
|
14
|
+
unlinkSync,
|
|
15
|
+
writeFileSync,
|
|
16
|
+
} from "node:fs";
|
|
7
17
|
import { dirname } from "node:path";
|
|
8
18
|
import {
|
|
9
19
|
getEnvApiKey,
|
|
@@ -29,15 +39,26 @@ export type OAuthCredential = {
|
|
|
29
39
|
|
|
30
40
|
/** Any single stored credential: either an API-key credential or an OAuth credential. */
export type AuthCredential = ApiKeyCredential | OAuthCredential;

/** What auth.json holds for one provider: a lone credential, or a list of them used for per-session / round-robin selection. */
export type AuthCredentialEntry = AuthCredential | Array<AuthCredential>;

/** Complete auth.json contents, keyed by provider id. */
export type AuthStorageData = { [provider: string]: AuthCredentialEntry };
|
|
33
45
|
|
|
34
46
|
/**
|
|
35
47
|
* Credential storage backed by a JSON file.
|
|
36
48
|
* Reads from multiple fallback paths, writes to primary path.
|
|
37
49
|
*/
|
|
38
50
|
export class AuthStorage {
|
|
51
|
+
// Lock tuning used to guard auth.json against concurrent writers
private static readonly lockRetryDelayMs = 50; // wait between attempts while someone else holds the lock
private static readonly lockTimeoutMs = 5000; // give up acquiring the lock after this long
private static readonly lockStaleMs = 30000; // a lock file older than this is treated as abandoned

/** In-memory copy of auth.json: provider id -> one credential or a list of credentials */
private data: AuthStorageData = {};
/** Provider id -> API key supplied at runtime; consulted before anything stored on disk */
private runtimeOverrides = new Map<string, string>();
/** "provider:type" -> index most recently handed out by round-robin selection */
private providerRoundRobinIndex = new Map<string, number>();
/** "provider:type" -> (sessionId -> credential index), so a session keeps the credential it was first given */
private sessionCredentialIndexes = new Map<string, Map<string, number>>();
/** Optional last-resort key lookup (e.g. models.json custom providers) */
private fallbackResolver?: (provider: string) => string | undefined;
|
|
42
63
|
|
|
43
64
|
/**
|
|
@@ -105,24 +126,244 @@ export class AuthStorage {
|
|
|
105
126
|
* Save credentials to disk.
|
|
106
127
|
*/
|
|
107
128
|
private async save(): Promise<void> {
	const fd = await this.acquireLock();
	const scratchFile = this.getTempPath();
	try {
		// Write a scratch file first and rename it over auth.json so the file is replaced in one step.
		const serialized = JSON.stringify(this.data, null, 2);
		writeFileSync(scratchFile, serialized, { mode: 0o600 });
		renameSync(scratchFile, this.authPath);
		// Re-apply owner-only permissions on the credentials file and its directory.
		chmodSync(this.authPath, 0o600);
		chmodSync(dirname(this.authPath), 0o700);
	} finally {
		// After a successful rename the scratch file is already gone (ENOENT is ignored);
		// after a failed write this removes the leftover.
		this.safeUnlink(scratchFile);
		this.releaseLock(fd);
	}
}
|
|
143
|
+
|
|
144
|
+
/** Path of the sibling lock file that guards writes to auth.json (`<authPath>.lock`). */
private getLockPath(): string {
	const lockSuffix = ".lock";
	return this.authPath + lockSuffix;
}
|
|
148
|
+
|
|
149
|
+
/** Per-write scratch path `<authPath>.tmp-<pid>-<ms timestamp>`, distinct across processes and across successive writes. */
private getTempPath(): string {
	const stamp = Date.now();
	return [this.authPath, ".tmp-", process.pid, "-", stamp].join("");
}
|
|
153
|
+
|
|
154
|
+
/**
 * Reports whether the lock file's mtime is more than `lockStaleMs` in the past, i.e. it was most
 * likely left behind by a process that died while holding it.
 * Any stat failure (including a missing file) is reported as "not stale".
 */
private isLockStale(lockPath: string): boolean {
	let ageMs: number;
	try {
		const { mtimeMs } = statSync(lockPath);
		ageMs = Date.now() - mtimeMs;
	} catch {
		return false;
	}
	return ageMs > AuthStorage.lockStaleMs;
}
|
|
113
163
|
|
|
114
164
|
/**
 * Acquires the exclusive auth.json lock by atomically creating the lock file
 * (flag "wx" fails with EEXIST when the file already exists).
 * While another holder owns the lock, retries every `lockRetryDelayMs` (fixed interval, no backoff)
 * until `lockTimeoutMs` has elapsed. A lock older than `lockStaleMs` is assumed to be orphaned by a
 * crashed process and is removed.
 * NOTE(review): two waiters can both judge the same lock stale; the later unlink may then remove the
 * lock the earlier waiter just re-created.
 * @returns File descriptor for the lock (must be passed to releaseLock)
 * @throws Error when the lock cannot be obtained within `lockTimeoutMs`; rethrows any fs error other than EEXIST.
 */
private async acquireLock(): Promise<number> {
	const lockPath = this.getLockPath();
	const start = Date.now();
	const timeoutMs = AuthStorage.lockTimeoutMs;
	const retryDelayMs = AuthStorage.lockRetryDelayMs;

	while (true) {
		try {
			// O_EXCL semantics: creation fails if the file exists, so acquisition is atomic
			return openSync(lockPath, "wx", 0o600);
		} catch (error) {
			const err = error as NodeJS.ErrnoException;
			if (err.code !== "EEXIST") {
				throw err;
			}
		}

		if (this.isLockStale(lockPath)) {
			this.safeUnlink(lockPath);
			logger.warn("AuthStorage lock was stale, removing", { path: lockPath });
			// Retry right away only if the stale lock is really gone. safeUnlink merely logs failures
			// (e.g. EPERM); without this check we would `continue` forever, never reaching the timeout.
			if (!existsSync(lockPath)) {
				continue;
			}
		}

		if (Date.now() - start > timeoutMs) {
			throw new Error(`Timed out waiting for auth lock: ${lockPath}`);
		}
		await new Promise((resolve) => setTimeout(resolve, retryDelayMs));
	}
}
|
|
196
|
+
|
|
197
|
+
/** Closes the lock descriptor (a failed close is logged, not thrown) and then deletes the lock file. */
private releaseLock(lockFd: number): void {
	try {
		closeSync(lockFd);
	} catch (closeError) {
		logger.warn("AuthStorage failed to close lock file", { error: String(closeError) });
	}
	this.safeUnlink(this.getLockPath());
}
|
|
207
|
+
|
|
208
|
+
/** Best-effort delete: a missing file is ignored silently; any other failure is logged rather than thrown. */
private safeUnlink(path: string): void {
	try {
		unlinkSync(path);
	} catch (error) {
		const { code } = error as NodeJS.ErrnoException;
		if (code === "ENOENT") return;
		logger.warn("AuthStorage failed to remove file", { path, error: String(error) });
	}
}
|
|
219
|
+
|
|
220
|
+
/** Converts a stored entry to a list: absent -> [], single credential -> [credential], array -> the same array (not copied). */
private normalizeCredentialEntry(entry: AuthCredentialEntry | undefined): AuthCredential[] {
	if (Array.isArray(entry)) return entry;
	return entry ? [entry] : [];
}
|
|
225
|
+
|
|
226
|
+
/** All stored credentials for `provider`, always as an array (empty when nothing is stored). */
private getCredentialsForProvider(provider: string): AuthCredential[] {
	const stored = this.data[provider];
	return this.normalizeCredentialEntry(stored);
}
|
|
230
|
+
|
|
231
|
+
/** Key for the round-robin and session maps, e.g. "anthropic:oauth" or "openai:api_key". */
private getProviderTypeKey(provider: string, type: AuthCredential["type"]): string {
	return [provider, type].join(":");
}
|
|
235
|
+
|
|
236
|
+
/**
 * Advances the round-robin cursor stored under `providerKey` and returns it, wrapping back to 0 after `total - 1`.
 * The first call for a key yields 0. With zero or one credential the cursor is left untouched and 0 is returned.
 */
private getNextRoundRobinIndex(providerKey: string, total: number): number {
	if (total > 1) {
		const previous = this.providerRoundRobinIndex.get(providerKey);
		const advanced = previous === undefined ? 0 : (previous + 1) % total;
		this.providerRoundRobinIndex.set(providerKey, advanced);
		return advanced;
	}
	return 0;
}
|
|
247
|
+
|
|
248
|
+
/**
 * Picks which of `total` credentials a caller should use.
 * - `total <= 1` or no `sessionId`: always index 0 (the round-robin cursor is not advanced).
 * - A session seen before whose remembered index is still in range: that same index (session-sticky).
 * - Otherwise: the next round-robin index, remembered for later calls from the same session.
 */
private selectCredentialIndex(providerKey: string, sessionId: string | undefined, total: number): number {
	if (total <= 1 || !sessionId) return 0;

	let assignments = this.sessionCredentialIndexes.get(providerKey);
	const remembered = assignments?.get(sessionId);
	if (remembered !== undefined && remembered < total) return remembered;

	const assigned = this.getNextRoundRobinIndex(providerKey, total);
	if (!assignments) {
		assignments = new Map<string, number>();
	}
	assignments.set(sessionId, assigned);
	this.sessionCredentialIndexes.set(providerKey, assignments);
	return assigned;
}
|
|
270
|
+
|
|
271
|
+
/**
 * Chooses one credential of the given `type` for `provider`, or undefined if none of that type is stored.
 * The returned `index` is the credential's position in the provider's full stored list (all types),
 * so callers can replace or remove it. When several candidates exist, the choice is delegated to
 * selectCredentialIndex (session-sticky / round-robin).
 */
private selectCredentialByType<T extends AuthCredential["type"]>(
	provider: string,
	type: T,
	sessionId?: string,
): { credential: Extract<AuthCredential, { type: T }>; index: number } | undefined {
	type Candidate = { credential: Extract<AuthCredential, { type: T }>; index: number };
	const isWanted = (credential: AuthCredential): credential is Extract<AuthCredential, { type: T }> =>
		credential.type === type;

	const candidates: Candidate[] = [];
	for (const [index, credential] of this.getCredentialsForProvider(provider).entries()) {
		if (isWanted(credential)) {
			candidates.push({ credential, index });
		}
	}

	switch (candidates.length) {
		case 0:
			return undefined;
		case 1:
			return candidates[0];
		default: {
			const pick = this.selectCredentialIndex(this.getProviderTypeKey(provider, type), sessionId, candidates.length);
			return candidates[pick];
		}
	}
}
|
|
295
|
+
|
|
296
|
+
/**
 * Forgets every round-robin cursor and session assignment whose key belongs to `provider`
 * (keys of the form "<provider>:<type>"). Called whenever the provider's stored credentials change,
 * because remembered indexes may no longer point at the same credential.
 */
private resetProviderAssignments(provider: string): void {
	const prefix = `${provider}:`;
	const trackers: Map<string, unknown>[] = [this.providerRoundRobinIndex, this.sessionCredentialIndexes];
	for (const tracker of trackers) {
		for (const key of [...tracker.keys()]) {
			if (key.startsWith(prefix)) tracker.delete(key);
		}
	}
}
|
|
312
|
+
|
|
313
|
+
/**
 * Overwrites the credential at `index` for `provider` (used to persist refreshed OAuth tokens).
 * Unknown providers and out-of-range indexes are ignored. A single (non-array) entry can only be
 * replaced at index 0. Round-robin/session state is kept because positions do not change.
 */
private replaceCredentialAt(provider: string, index: number, credential: AuthCredential): void {
	const entry = this.data[provider];
	if (!entry) return;

	if (!Array.isArray(entry)) {
		// A lone credential occupies position 0 only.
		if (index === 0) this.data[provider] = credential;
		return;
	}

	const inRange = index >= 0 && index < entry.length;
	if (!inRange) return;
	// Copy instead of mutating, so the previous array object is left untouched.
	const copy = Array.from(entry);
	copy[index] = credential;
	this.data[provider] = copy;
}
|
|
331
|
+
|
|
332
|
+
/**
 * Removes the credential at `index` for `provider` (used when an OAuth refresh fails).
 * Deletes the provider entry entirely once its last credential is gone, then clears the provider's
 * round-robin/session assignments because positions have shifted.
 * Indexes that do not address a stored credential are ignored instead of deleting something else.
 * For example, an index can go out of range while getApiKey awaits a refresh. This mirrors the
 * bounds handling in replaceCredentialAt.
 */
private removeCredentialAt(provider: string, index: number): void {
	const entry = this.data[provider];
	if (!entry) return;

	if (Array.isArray(entry)) {
		if (index < 0 || index >= entry.length) return;
		const remaining = entry.filter((_value, idx) => idx !== index);
		if (remaining.length > 0) {
			this.data[provider] = remaining;
		} else {
			delete this.data[provider];
		}
	} else {
		// A lone credential only exists at position 0.
		if (index !== 0) return;
		delete this.data[provider];
	}

	this.resetProviderAssignments(provider);
}
|
|
353
|
+
|
|
354
|
+
/**
 * Returns the first stored credential for `provider` (of any type), or undefined when none is stored.
 */
get(provider: string): AuthCredential | undefined {
	const [first] = this.getCredentialsForProvider(provider);
	return first;
}
|
|
120
360
|
|
|
121
361
|
/**
 * Stores `credential` (one credential or a list) as the complete entry for `provider`, replacing
 * whatever was there. Resets the provider's round-robin/session assignments, then saves to disk.
 */
async set(provider: string, credential: AuthCredentialEntry): Promise<void> {
	const entries = this.data;
	entries[provider] = credential;
	// Remembered indexes may now point at different credentials.
	this.resetProviderAssignments(provider);
	await this.save();
}
|
|
128
369
|
|
|
@@ -131,6 +372,7 @@ export class AuthStorage {
|
|
|
131
372
|
*/
|
|
132
373
|
async remove(provider: string): Promise<void> {
	const stored = this.data;
	delete stored[provider];
	// Indexes remembered for this provider refer to credentials that no longer exist.
	this.resetProviderAssignments(provider);
	await this.save();
}
|
|
136
378
|
|
|
@@ -145,7 +387,7 @@ export class AuthStorage {
|
|
|
145
387
|
* Check if credentials exist for a provider in auth.json.
|
|
146
388
|
*/
|
|
147
389
|
has(provider: string): boolean {
	// Only auth.json contents count; runtime overrides, env vars and the fallback resolver are not consulted.
	return this.getCredentialsForProvider(provider).length !== 0;
}
|
|
150
392
|
|
|
151
393
|
/**
|
|
@@ -154,14 +396,30 @@ export class AuthStorage {
|
|
|
154
396
|
*/
|
|
155
397
|
hasAuth(provider: string): boolean {
	// Checked in order: runtime override, stored credentials, environment variable, fallback resolver.
	return (
		this.runtimeOverrides.has(provider) ||
		this.getCredentialsForProvider(provider).length > 0 ||
		Boolean(getEnvApiKey(provider)) ||
		Boolean(this.fallbackResolver?.(provider))
	);
}
|
|
162
404
|
|
|
163
405
|
/**
 * True when at least one stored credential for `provider` has type "oauth".
 */
hasOAuth(provider: string): boolean {
	for (const credential of this.getCredentialsForProvider(provider)) {
		if (credential.type === "oauth") return true;
	}
	return false;
}
|
|
411
|
+
|
|
412
|
+
/**
 * Returns the first stored OAuth credential for `provider`, skipping API-key entries;
 * undefined when there is none. No session-sticky selection is applied here.
 */
getOAuthCredential(provider: string): OAuthCredential | undefined {
	const isOAuth = (credential: AuthCredential): credential is OAuthCredential => credential.type === "oauth";
	for (const credential of this.getCredentialsForProvider(provider)) {
		if (isOAuth(credential)) return credential;
	}
	return undefined;
}
|
|
420
|
+
|
|
421
|
+
/**
|
|
422
|
+
* Get all credentials.
|
|
165
423
|
*/
|
|
166
424
|
getAll(): AuthStorageData {
|
|
167
425
|
return { ...this.data };
|
|
@@ -207,7 +465,14 @@ export class AuthStorage {
|
|
|
207
465
|
throw new Error(`Unknown OAuth provider: ${provider}`);
|
|
208
466
|
}
|
|
209
467
|
|
|
210
|
-
|
|
468
|
+
const newCredential: OAuthCredential = { type: "oauth", ...credentials };
|
|
469
|
+
const existing = this.getCredentialsForProvider(provider);
|
|
470
|
+
if (existing.length === 0) {
|
|
471
|
+
await this.set(provider, newCredential);
|
|
472
|
+
return;
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
await this.set(provider, [...existing, newCredential]);
|
|
211
476
|
}
|
|
212
477
|
|
|
213
478
|
/**
|
|
@@ -226,37 +491,37 @@ export class AuthStorage {
|
|
|
226
491
|
* 4. Environment variable
|
|
227
492
|
* 5. Fallback resolver (models.json custom providers)
|
|
228
493
|
*/
|
|
229
|
-
async getApiKey(provider: string): Promise<string | undefined> {
|
|
494
|
+
async getApiKey(provider: string, sessionId?: string): Promise<string | undefined> {
|
|
230
495
|
// Runtime override takes highest priority
|
|
231
496
|
const runtimeKey = this.runtimeOverrides.get(provider);
|
|
232
497
|
if (runtimeKey) {
|
|
233
498
|
return runtimeKey;
|
|
234
499
|
}
|
|
235
500
|
|
|
236
|
-
const
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
return cred.key;
|
|
501
|
+
const apiKeySelection = this.selectCredentialByType(provider, "api_key", sessionId);
|
|
502
|
+
if (apiKeySelection) {
|
|
503
|
+
return apiKeySelection.credential.key;
|
|
240
504
|
}
|
|
241
505
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
const oauthCreds: Record<string, OAuthCredentials> = {
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
oauthCreds[key] = value;
|
|
248
|
-
}
|
|
249
|
-
}
|
|
506
|
+
const oauthSelection = this.selectCredentialByType(provider, "oauth", sessionId);
|
|
507
|
+
if (oauthSelection) {
|
|
508
|
+
const oauthCreds: Record<string, OAuthCredentials> = {
|
|
509
|
+
[provider]: oauthSelection.credential,
|
|
510
|
+
};
|
|
250
511
|
|
|
251
512
|
try {
|
|
252
513
|
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
|
|
253
514
|
if (result) {
|
|
254
|
-
this.
|
|
515
|
+
this.replaceCredentialAt(provider, oauthSelection.index, { type: "oauth", ...result.newCredentials });
|
|
255
516
|
await this.save();
|
|
256
517
|
return result.apiKey;
|
|
257
518
|
}
|
|
258
519
|
} catch {
|
|
259
|
-
|
|
520
|
+
this.removeCredentialAt(provider, oauthSelection.index);
|
|
521
|
+
await this.save();
|
|
522
|
+
if (this.getCredentialsForProvider(provider).some((credential) => credential.type === "oauth")) {
|
|
523
|
+
return this.getApiKey(provider, sessionId);
|
|
524
|
+
}
|
|
260
525
|
}
|
|
261
526
|
}
|
|
262
527
|
|
|
@@ -187,8 +187,8 @@ export class ModelRegistry {
|
|
|
187
187
|
const combined = [...builtInModels, ...customModels];
|
|
188
188
|
|
|
189
189
|
// Update github-copilot base URL based on OAuth credentials
|
|
190
|
-
const copilotCred = this.authStorage.
|
|
191
|
-
if (copilotCred
|
|
190
|
+
const copilotCred = this.authStorage.getOAuthCredential("github-copilot");
|
|
191
|
+
if (copilotCred) {
|
|
192
192
|
const domain = copilotCred.enterpriseUrl
|
|
193
193
|
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
|
|
194
194
|
: undefined;
|
|
@@ -390,22 +390,21 @@ export class ModelRegistry {
|
|
|
390
390
|
/**
 * Resolves the API key for `model` by asking AuthStorage for the model's provider.
 * @param sessionId Optional session id, forwarded so AuthStorage can keep a session on the same credential.
 */
async getApiKey(model: Model<Api>, sessionId?: string): Promise<string | undefined> {
	const { provider } = model;
	return this.authStorage.getApiKey(provider, sessionId);
}
|
|
396
396
|
|
|
397
397
|
/**
 * Resolves the API key for a provider id (e.g. "openai") through AuthStorage.
 * @param sessionId Optional session id, forwarded for session-sticky credential selection.
 */
async getApiKeyForProvider(provider: string, sessionId?: string): Promise<string | undefined> {
	const key = await this.authStorage.getApiKey(provider, sessionId);
	return key;
}
|
|
403
403
|
|
|
404
404
|
/**
 * Check if a model is using OAuth credentials (subscription).
 *
 * AuthStorage.getApiKey picks a stored api_key credential before any OAuth credential, so a
 * provider that has both kinds stored is actually served by its API key. Report OAuth only when
 * that is what would be selected from auth.json. Runtime overrides and environment variables are
 * not considered here.
 */
isUsingOAuth(model: Model<Api>): boolean {
	const entry = this.authStorage.getAll()[model.provider];
	if (!entry) return false;
	const credentials = Array.isArray(entry) ? entry : [entry];
	const hasApiKey = credentials.some((credential) => credential.type === "api_key");
	return !hasApiKey && credentials.some((credential) => credential.type === "oauth");
}
|
|
411
410
|
}
|
package/src/core/sdk.ts
CHANGED
|
@@ -538,6 +538,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
538
538
|
|
|
539
539
|
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
|
|
540
540
|
time("sessionManager");
|
|
541
|
+
const sessionId = sessionManager.getSessionId();
|
|
541
542
|
|
|
542
543
|
// Check if session has existing data to restore
|
|
543
544
|
const existingSession = sessionManager.buildSessionContext();
|
|
@@ -554,7 +555,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
554
555
|
const parsedModel = parseModelString(defaultModelStr);
|
|
555
556
|
if (parsedModel) {
|
|
556
557
|
const restoredModel = modelRegistry.find(parsedModel.provider, parsedModel.id);
|
|
557
|
-
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
|
|
558
|
+
if (restoredModel && (await modelRegistry.getApiKey(restoredModel, sessionId))) {
|
|
558
559
|
model = restoredModel;
|
|
559
560
|
}
|
|
560
561
|
}
|
|
@@ -570,7 +571,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
570
571
|
const parsedModel = parseModelString(settingsDefaultModel);
|
|
571
572
|
if (parsedModel) {
|
|
572
573
|
const settingsModel = modelRegistry.find(parsedModel.provider, parsedModel.id);
|
|
573
|
-
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
|
|
574
|
+
if (settingsModel && (await modelRegistry.getApiKey(settingsModel, sessionId))) {
|
|
574
575
|
model = settingsModel;
|
|
575
576
|
}
|
|
576
577
|
}
|
|
@@ -580,7 +581,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
580
581
|
// Fall back to first available model with a valid API key
|
|
581
582
|
if (!model) {
|
|
582
583
|
for (const m of modelRegistry.getAll()) {
|
|
583
|
-
if (await modelRegistry.getApiKey(m)) {
|
|
584
|
+
if (await modelRegistry.getApiKey(m, sessionId)) {
|
|
584
585
|
model = m;
|
|
585
586
|
break;
|
|
586
587
|
}
|
|
@@ -633,6 +634,9 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
633
634
|
const contextFiles = options.contextFiles ?? discoverContextFiles(cwd, agentDir);
|
|
634
635
|
time("discoverContextFiles");
|
|
635
636
|
|
|
637
|
+
let agent: Agent;
|
|
638
|
+
let session: AgentSession;
|
|
639
|
+
|
|
636
640
|
const toolSession: ToolSession = {
|
|
637
641
|
cwd,
|
|
638
642
|
hasUI: options.hasUI ?? false,
|
|
@@ -643,6 +647,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
643
647
|
getSessionFile: () => sessionManager.getSessionFile() ?? null,
|
|
644
648
|
getSessionSpawns: () => options.spawns ?? "*",
|
|
645
649
|
getModelString: () => (hasExplicitModel && model ? formatModelString(model) : undefined),
|
|
650
|
+
getActiveModelString: () => {
|
|
651
|
+
const activeModel = agent?.state.model;
|
|
652
|
+
return activeModel ? formatModelString(activeModel) : undefined;
|
|
653
|
+
},
|
|
646
654
|
settings: settingsManager,
|
|
647
655
|
};
|
|
648
656
|
|
|
@@ -782,8 +790,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
782
790
|
extensionRunner = new ExtensionRunner(extensionsResult.extensions, cwd, sessionManager, modelRegistry);
|
|
783
791
|
}
|
|
784
792
|
|
|
785
|
-
let agent: Agent;
|
|
786
|
-
let session: AgentSession;
|
|
787
793
|
const getSessionContext = () => ({
|
|
788
794
|
sessionManager,
|
|
789
795
|
modelRegistry,
|
|
@@ -916,7 +922,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|
|
916
922
|
if (!currentModel) {
|
|
917
923
|
throw new Error("No model selected");
|
|
918
924
|
}
|
|
919
|
-
const key = await modelRegistry.getApiKey(currentModel);
|
|
925
|
+
const key = await modelRegistry.getApiKey(currentModel, sessionId);
|
|
920
926
|
if (!key) {
|
|
921
927
|
throw new Error(`No API key found for provider "${currentModel.provider}"`);
|
|
922
928
|
}
|
|
@@ -1338,6 +1338,21 @@ export class SessionManager {
|
|
|
1338
1338
|
return this.leafId ? this.byId.get(this.leafId) : undefined;
|
|
1339
1339
|
}
|
|
1340
1340
|
|
|
1341
|
+
/**
 * Walks from the current leaf entry back through parent links and returns the role of the nearest
 * "model_change" entry, using "default" when that entry carries no role.
 * Returns undefined if no model change appears on the path.
 */
getLastModelChangeRole(): string | undefined {
	for (
		let entry = this.getLeafEntry();
		entry;
		entry = entry.parentId ? this.byId.get(entry.parentId) : undefined
	) {
		if (entry.type === "model_change") return entry.role ?? "default";
	}
	return undefined;
}
|
|
1355
|
+
|
|
1341
1356
|
/** Looks up a session entry by id; undefined when the id is unknown. */
getEntry(id: string): SessionEntry | undefined {
	const found = this.byId.get(id);
	return found;
}
|
|
@@ -371,7 +371,8 @@ export class SettingsManager {
|
|
|
371
371
|
private settingsPath: string | null;
|
|
372
372
|
private cwd: string | null;
|
|
373
373
|
private globalSettings: Settings;
|
|
374
|
-
private
|
|
374
|
+
private overrides: Settings;
|
|
375
|
+
private settings!: Settings;
|
|
375
376
|
private persist: boolean;
|
|
376
377
|
|
|
377
378
|
private constructor(settingsPath: string | null, cwd: string | null, initialSettings: Settings, persist: boolean) {
|
|
@@ -379,8 +380,8 @@ export class SettingsManager {
|
|
|
379
380
|
this.cwd = cwd;
|
|
380
381
|
this.persist = persist;
|
|
381
382
|
this.globalSettings = initialSettings;
|
|
382
|
-
|
|
383
|
-
this.
|
|
383
|
+
this.overrides = {};
|
|
384
|
+
this.rebuildSettings();
|
|
384
385
|
|
|
385
386
|
// Apply environment variables from settings
|
|
386
387
|
this.applyEnvironmentVariables();
|
|
@@ -474,9 +475,17 @@ export class SettingsManager {
|
|
|
474
475
|
return SettingsManager.migrateSettings(merged as Record<string, unknown>);
|
|
475
476
|
}
|
|
476
477
|
|
|
478
|
+
/**
 * Recomputes the effective settings by merging, in order: global settings, project settings,
 * then the accumulated runtime overrides; the result is normalized.
 * Project settings are reloaded from disk unless the caller passes them in.
 */
private rebuildSettings(projectSettings?: Settings): void {
	const project = projectSettings ?? this.loadProjectSettings();
	const globalWithProject = deepMergeSettings(this.globalSettings, project);
	const withOverrides = deepMergeSettings(globalWithProject, this.overrides);
	this.settings = normalizeSettings(withOverrides);
}
|
|
484
|
+
|
|
477
485
|
/** Apply additional overrides on top of current settings */
applyOverrides(overrides: Partial<Settings>): void {
	// Accumulate into the persistent override layer so later rebuilds (e.g. after save) keep these values.
	const accumulated = deepMergeSettings(this.overrides, overrides);
	this.overrides = accumulated;
	this.rebuildSettings();
}
|
|
481
490
|
|
|
482
491
|
private save(): void {
|
|
@@ -491,9 +500,9 @@ export class SettingsManager {
|
|
|
491
500
|
// Save only global settings (project settings are read-only)
|
|
492
501
|
writeFileSync(this.settingsPath, JSON.stringify(this.globalSettings, null, 2), "utf-8");
|
|
493
502
|
|
|
494
|
-
// Re-merge project settings into active settings
|
|
503
|
+
// Re-merge project settings into active settings (preserve overrides)
|
|
495
504
|
const projectSettings = this.loadProjectSettings();
|
|
496
|
-
this.
|
|
505
|
+
this.rebuildSettings(projectSettings);
|
|
497
506
|
} catch (error) {
|
|
498
507
|
console.error(`Warning: Could not save settings file: ${error}`);
|
|
499
508
|
}
|
|
@@ -523,6 +532,11 @@ export class SettingsManager {
|
|
|
523
532
|
this.globalSettings.modelRoles = {};
|
|
524
533
|
}
|
|
525
534
|
this.globalSettings.modelRoles[role] = model;
|
|
535
|
+
|
|
536
|
+
if (this.overrides.modelRoles && this.overrides.modelRoles[role] !== undefined) {
|
|
537
|
+
this.overrides.modelRoles[role] = model;
|
|
538
|
+
}
|
|
539
|
+
|
|
526
540
|
this.save();
|
|
527
541
|
}
|
|
528
542
|
|
|
@@ -72,6 +72,7 @@ const toolDescriptions: Record<ToolName, string> = {
|
|
|
72
72
|
ask: "Ask user for input or clarification",
|
|
73
73
|
read: "Read file contents",
|
|
74
74
|
bash: "Execute bash commands (npm, docker, etc.)",
|
|
75
|
+
calc: "{ calculations: array of { expression: string, prefix: string, suffix: string } } Basic calculations.",
|
|
75
76
|
ssh: "Execute commands on remote hosts via SSH",
|
|
76
77
|
edit: "Make surgical edits to files (find exact text and replace)",
|
|
77
78
|
write: "Create or overwrite files",
|
|
@@ -68,11 +68,13 @@ export async function findTitleModel(registry: ModelRegistry, savedSmolModel?: s
|
|
|
68
68
|
* @param firstMessage The first user message
|
|
69
69
|
* @param registry Model registry
|
|
70
70
|
* @param savedSmolModel Optional saved smol model from settings (provider/modelId format)
|
|
71
|
+
* @param sessionId Optional session id for sticky API key selection
|
|
71
72
|
*/
|
|
72
73
|
export async function generateSessionTitle(
|
|
73
74
|
firstMessage: string,
|
|
74
75
|
registry: ModelRegistry,
|
|
75
76
|
savedSmolModel?: string,
|
|
77
|
+
sessionId?: string,
|
|
76
78
|
): Promise<string | null> {
|
|
77
79
|
const candidates = getTitleModelCandidates(registry, savedSmolModel);
|
|
78
80
|
if (candidates.length === 0) {
|
|
@@ -86,7 +88,7 @@ export async function generateSessionTitle(
|
|
|
86
88
|
const userMessage = `<user-message>\n${truncatedMessage}\n</user-message>`;
|
|
87
89
|
|
|
88
90
|
for (const model of candidates) {
|
|
89
|
-
const apiKey = await registry.getApiKey(model);
|
|
91
|
+
const apiKey = await registry.getApiKey(model, sessionId);
|
|
90
92
|
if (!apiKey) {
|
|
91
93
|
logger.debug("title-generator: no API key for model", { provider: model.provider, id: model.id });
|
|
92
94
|
continue;
|