@robhowley/pi-openrouter 0.7.0 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,313 @@
1
+ /**
2
+ * Tests for the sync engine.
3
+ */
4
+
5
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
6
+ import type { ExtensionContext, ModelRegistry } from '@mariozechner/pi-coding-agent';
7
+ import type { SyncResult } from '../types.js';
8
+
9
+ // Import modules
10
+ import {
11
+ syncModels,
12
+ setSyncState,
13
+ getSyncState,
14
+ getStatusText,
15
+ areModelsAvailable,
16
+ } from '../sync.js';
17
+ import { fetchUserModels, AuthError } from '../../client.js';
18
+ import { loadCache } from '../cache.js';
19
+
20
// Replace the API client with controllable mocks. Vitest hoists vi.mock calls
// above the imports, so each factory must be self-contained.
vi.mock('../../client.js', () => {
  // Stand-in for the real AuthError so tests can construct and throw it.
  class AuthError extends Error {
    constructor(message: string) {
      super(message);
      this.name = 'AuthError';
    }
  }

  return {
    fetchUserModels: vi.fn(),
    isConfigured: vi.fn(),
    getApiKey: vi.fn(),
    AuthError,
  };
});

// Replace the cache module so the real implementation is not used by these tests.
vi.mock('../cache.js', () => ({
  loadCache: vi.fn(),
  saveCache: vi.fn(),
  setCacheDir: vi.fn(),
}));
39
+
40
/**
 * Build a minimal ExtensionContext for the sync tests.
 *
 * Only `modelRegistry.registerProvider` can be injected; every other member is
 * a bare stub (or a fixed value) so the object satisfies ExtensionContext.
 *
 * @param overrides.registerProvider - mock installed as
 *   `modelRegistry.registerProvider`; defaults to a fresh `vi.fn()`.
 */
function createMockExtensionContext(
  overrides: {
    // A mock *instance* (what vi.fn() returns), not the vi.fn factory itself.
    registerProvider?: ReturnType<typeof vi.fn>;
  } = {},
): ExtensionContext {
  return {
    modelRegistry: {
      registerProvider: overrides.registerProvider ?? vi.fn(),
    } as unknown as ModelRegistry,
    ui: createMockUI(),
    hasUI: true,
    cwd: '/tmp',
    sessionManager: createMockSessionManager(),
    model: undefined,
    isIdle: () => true,
    signal: undefined,
    abort: vi.fn(),
    hasPendingMessages: () => false,
    shutdown: vi.fn(),
    getContextUsage: vi.fn(),
    compact: vi.fn(),
    getSystemPrompt: vi.fn(),
  } satisfies ExtensionContext;
}
70
+
71
/**
 * Build a stub for `ExtensionContext['ui']`.
 *
 * Every method is its own implementation-less `vi.fn()`, so calls return
 * undefined and are recorded; `theme` is an empty placeholder object.
 */
function createMockUI(): ExtensionContext['ui'] {
  return {
    select: vi.fn(),
    confirm: vi.fn(),
    input: vi.fn(),
    notify: vi.fn(),
    onTerminalInput: vi.fn(),
    setStatus: vi.fn(),
    setWorkingMessage: vi.fn(),
    setWorkingIndicator: vi.fn(),
    setHiddenThinkingLabel: vi.fn(),
    setWidget: vi.fn(),
    setFooter: vi.fn(),
    setHeader: vi.fn(),
    setTitle: vi.fn(),
    custom: vi.fn(),
    pasteToEditor: vi.fn(),
    setEditorText: vi.fn(),
    getEditorText: vi.fn(),
    editor: vi.fn(),
    setEditorComponent: vi.fn(),
    // Placeholder only; no test reads theme properties.
    theme: {} as any,
    getAllThemes: vi.fn(),
    getTheme: vi.fn(),
    setTheme: vi.fn(),
    getToolsExpanded: vi.fn(),
    setToolsExpanded: vi.fn(),
  };
}
101
+
102
/**
 * Build a stub session manager.
 *
 * The sync tests do not use the session manager. This stub lists a handful of
 * methods as bare `vi.fn()` mocks and is cast with `as any` only so it fits
 * the ExtensionContext field type.
 */
function createMockSessionManager(): ExtensionContext['sessionManager'] {
  const stub = {
    getCurrentSessionId: vi.fn(),
    getCurrentSessionPath: vi.fn(),
    getEntry: vi.fn(),
    getEntryHistory: vi.fn(),
    getEntryById: vi.fn(),
    getBranchEntries: vi.fn(),
    getBranchSummary: vi.fn(),
    getRecentEntries: vi.fn(),
    getEntryCount: vi.fn(),
  };
  return stub as any;
}
121
+
122
/**
 * syncModels: the API-failure path (no key, no cache) and the API-success path.
 * The client and cache modules are mocked at the top of the file.
 */
describe('syncModels', () => {
  const mockRegisterProvider = vi.fn();
  const mockCtx = createMockExtensionContext({ registerProvider: mockRegisterProvider });

  beforeEach(() => {
    vi.resetAllMocks();
    (setSyncState as (result: SyncResult | null) => void)(null);

    // Run each test without an API key, then restore whatever the environment
    // originally had so the deletion does not leak beyond this suite.
    const savedApiKey = process.env['OPENROUTER_API_KEY'];
    delete process.env['OPENROUTER_API_KEY'];

    // Vitest runs a function returned from beforeEach as per-test cleanup.
    return () => {
      if (savedApiKey === undefined) {
        delete process.env['OPENROUTER_API_KEY'];
      } else {
        process.env['OPENROUTER_API_KEY'] = savedApiKey;
      }
    };
  });

  it('should return failure when API key is missing and no cache', async () => {
    // The API call fails with an auth error...
    vi.mocked(fetchUserModels).mockRejectedValueOnce(new AuthError('OPENROUTER_API_KEY not set'));
    // ...and there is no cache to fall back on.
    vi.mocked(loadCache).mockResolvedValueOnce(null);

    const result = await syncModels(mockCtx);

    expect(result.success).toBe(false);
    expect(result.registeredCount).toBe(0);
    expect(result.source).toBe('none');
    expect(result.error).toContain('OPENROUTER_API_KEY not set');
  });

  it('should sync models from API and register with provider', async () => {
    // Minimal SDK-shaped (camelCase) model returned by the mocked API.
    const mockModel = {
      id: 'anthropic/claude-3-opus',
      name: 'Claude 3 Opus',
      architecture: {
        inputModalities: ['text', 'image'],
        outputModalities: ['text'],
      },
      contextLength: 200000,
      pricing: {
        prompt: 0.000015,
        completion: 0.000075,
        inputCacheRead: 0.0000015,
        inputCacheWrite: 0.0000075,
      },
      supportedParameters: ['reasoning'],
      topProvider: {
        contextLength: 200000,
        maxCompletionTokens: 4096,
      },
    };

    vi.mocked(fetchUserModels).mockResolvedValueOnce({
      data: [mockModel],
    } as any);

    vi.mocked(loadCache).mockResolvedValueOnce(null);

    const result = await syncModels(mockCtx);

    expect(result.success).toBe(true);
    expect(result.source).toBe('api');
    expect(result.registeredCount).toBeGreaterThan(0);
    expect(mockRegisterProvider).toHaveBeenCalled();
  });
});
186
+
187
/**
 * setSyncState / getSyncState round-trip behaviour.
 */
describe('syncState management', () => {
  // setSyncState's declared parameter does not admit null; widen it locally to reset state.
  const setNullableSyncState = setSyncState as (result: SyncResult | null) => void;

  it('should store and retrieve sync state', () => {
    const expected: SyncResult = {
      success: true,
      registeredCount: 10,
      skippedCount: 2,
      source: 'api',
      cacheUpdated: true,
      cacheAgeMs: null,
      error: null,
    };

    setSyncState(expected);

    expect(getSyncState()).toEqual(expected);
  });

  it('should return null when no state set', () => {
    setNullableSyncState(null);
    expect(getSyncState()).toBeNull();
  });
});
210
+
211
/**
 * getStatusText for each sync outcome: none, API success, cache fallback,
 * and complete failure.
 *
 * The state objects are plain literals rather than `as SyncResult` casts, so
 * the compiler checks every field against SyncResult.
 */
describe('getStatusText', () => {
  beforeEach(() => {
    // setSyncState's declared parameter does not admit null; widen it to reset state.
    (setSyncState as (result: SyncResult | null) => void)(null);
  });

  it('should return not synced when no state', () => {
    expect(getStatusText()).toBe('OpenRouter models: not synced');
  });

  it('should return healthy for successful sync', () => {
    setSyncState({
      success: true,
      registeredCount: 312,
      skippedCount: 18,
      source: 'api',
      cacheUpdated: true,
      cacheAgeMs: null,
      error: null,
    });
    const text = getStatusText();
    expect(text).toContain('healthy');
    expect(text).toContain('312 registered');
  });

  it('should return cached for cache fallback', () => {
    setSyncState({
      success: false,
      registeredCount: 287,
      skippedCount: 21,
      source: 'cache',
      cacheUpdated: false,
      cacheAgeMs: 7200000, // 2 hours
      error: '401 unauthorized',
    });
    const text = getStatusText();
    expect(text).toContain('cached');
    expect(text).toContain('287 registered');
  });

  it('should return broken for complete failure', () => {
    setSyncState({
      success: false,
      registeredCount: 0,
      skippedCount: 0,
      source: 'none',
      cacheUpdated: false,
      cacheAgeMs: null,
      error: 'missing or invalid OpenRouter auth',
    });
    const text = getStatusText();
    expect(text).toContain('broken');
    expect(text).toContain('0 registered');
  });
});
265
+
266
/**
 * areModelsAvailable across sync states.
 *
 * The state objects are plain literals rather than `as SyncResult` casts, so
 * the compiler checks every field against SyncResult.
 */
describe('areModelsAvailable', () => {
  beforeEach(() => {
    // setSyncState's declared parameter does not admit null; widen it to reset state.
    (setSyncState as (result: SyncResult | null) => void)(null);
  });

  it('should return false when no state', async () => {
    expect(await areModelsAvailable()).toBe(false);
  });

  it('should return true when models are synced', async () => {
    setSyncState({
      success: true,
      registeredCount: 10,
      skippedCount: 0,
      source: 'api',
      cacheUpdated: true,
      cacheAgeMs: null,
      error: null,
    });
    expect(await areModelsAvailable()).toBe(true);
  });

  it('should return true when using cache (models still available)', async () => {
    setSyncState({
      success: false,
      registeredCount: 5,
      skippedCount: 0,
      source: 'cache',
      cacheUpdated: false,
      cacheAgeMs: 7200000,
      error: 'API error',
    });
    expect(await areModelsAvailable()).toBe(true);
  });

  it('should return false when no models registered', async () => {
    setSyncState({
      success: false,
      registeredCount: 0,
      skippedCount: 0,
      source: 'none',
      cacheUpdated: false,
      cacheAgeMs: null,
      error: 'Complete failure',
    });
    expect(await areModelsAvailable()).toBe(false);
  });
});
@@ -0,0 +1,105 @@
1
+ import { readFile, writeFile, mkdir } from 'fs/promises';
2
+ import { join } from 'path';
3
+ import { homedir } from 'os';
4
+ import type { ModelsCache } from './types.js';
5
+ import { MS_PER_MINUTE } from './types.js';
6
+
7
/** Name of the JSON file holding the cached model list. */
const CACHE_FILENAME = 'models-cache.json';

/** Default cache directory: `<home>/.pi/openrouter`. */
const DEFAULT_CACHE_DIR = join(homedir(), '.pi', 'openrouter');

/**
 * Test hook. When non-null, this directory is used instead of DEFAULT_CACHE_DIR.
 * It is changed through setCacheDir.
 */
let cacheDirOverride: string | null = null;
12
+
13
/**
 * Resolve the directory holding the cache file.
 * Returns the test override when one is set, otherwise DEFAULT_CACHE_DIR.
 */
function getCacheDir(): string {
  if (cacheDirOverride == null) {
    return DEFAULT_CACHE_DIR;
  }
  return cacheDirOverride;
}
20
+
21
/**
 * Point the cache at another directory (intended for tests).
 *
 * @param dir - directory to use instead of the default, or `null` to go back
 *   to the default location.
 */
export function setCacheDir(dir: string | null): void {
  cacheDirOverride = dir;
}
28
+
29
/**
 * Path of the cache file (CACHE_FILENAME) inside the current cache directory.
 */
function getCachePath(): string {
  const dir = getCacheDir();
  return join(dir, CACHE_FILENAME);
}
35
+
36
/**
 * Create the cache directory, including any missing parents.
 * With `recursive: true`, mkdir does not fail if the directory already exists.
 */
async function ensureCacheDir(): Promise<void> {
  const dir = getCacheDir();
  await mkdir(dir, { recursive: true });
}
42
+
43
/**
 * Load cached models from disk.
 *
 * Returns null when the cache file is missing or unreadable, or is not valid
 * JSON. It also returns null when the parsed value lacks the expected
 * `{ models: [...], timestamp: number }` shape. A corrupt cache is treated
 * like a missing one rather than raising.
 */
export async function loadCache(): Promise<ModelsCache | null> {
  let parsed: unknown;
  try {
    const data = await readFile(getCachePath(), 'utf-8');
    parsed = JSON.parse(data);
  } catch {
    // File doesn't exist, permission error, or invalid JSON
    return null;
  }

  // Narrow explicitly instead of trusting a cast: JSON.parse may yield null,
  // a primitive, or an array, none of which is a usable cache.
  if (typeof parsed !== 'object' || parsed === null) {
    return null;
  }
  const candidate = parsed as { models?: unknown; timestamp?: unknown };
  if (!Array.isArray(candidate.models) || typeof candidate.timestamp !== 'number') {
    return null;
  }

  return parsed as ModelsCache;
}
64
+
65
/**
 * Write the models cache to disk as pretty-printed (2-space) JSON.
 * The cache directory is created first if needed.
 *
 * @param cache - cache contents to persist
 */
export async function saveCache(cache: ModelsCache): Promise<void> {
  await ensureCacheDir();
  const payload = JSON.stringify(cache, null, 2);
  await writeFile(getCachePath(), payload);
}
73
+
74
/**
 * Milliseconds elapsed since the cache's `timestamp` (Date.now() minus timestamp).
 * The result is negative if the timestamp lies in the future.
 */
export function getCacheAgeMs(cache: ModelsCache): number {
  const { timestamp } = cache;
  return Date.now() - timestamp;
}
80
+
81
/**
 * Format a millisecond duration for display.
 *
 * Uses whole, floored units: "<1m" below one minute (including negative
 * values), then "Nm", "Nh" and "Nd".
 * Examples: "<1m", "4m", "2h", "1d".
 *
 * @param ms - duration in milliseconds, or null when unknown
 * @returns the formatted duration, or "unknown" for null or non-finite input
 */
export function formatDuration(ms: number | null): string {
  // NaN/±Infinity would otherwise slip past every threshold and render as
  // "NaNd" / "Infinityd"; report them like a missing duration.
  if (ms === null || !Number.isFinite(ms)) return 'unknown';

  const minutes = Math.floor(ms / MS_PER_MINUTE);
  if (minutes < 1) return '<1m';
  if (minutes < 60) return `${minutes}m`;

  const hours = Math.floor(minutes / 60);
  if (hours < 24) return `${hours}h`;

  return `${Math.floor(hours / 24)}d`;
}
97
+
98
/**
 * Human-readable age of a cache (formatDuration applied to getCacheAgeMs).
 * Examples: "4m", "2h", "1d". Returns null when no cache is given.
 */
export function formatCacheAge(cache: ModelsCache | null): string | null {
  return cache ? formatDuration(getCacheAgeMs(cache)) : null;
}
@@ -0,0 +1,182 @@
1
+ import type { OpenRouterModel, PiModelConfig, SkipReason, MapResult } from './types.js';
2
+ import { ROUTER_ALIASES } from './types.js';
3
+ import type { Model as SDKModel } from '@openrouter/sdk/models/index.js';
4
+
5
/** Multiplier that scales a per-token price into a per-million-tokens price. */
const COST_PER_MILLION = 1e6;

/** maxTokens used when a model reports no usable completion-token limit. */
const DEFAULT_MAX_TOKENS = 4_096;
7
+
8
/**
 * Convert an SDK (camelCase) Model into the snake_case OpenRouterModel shape.
 *
 * Absent SDK numbers become 0 and absent modality lists become []. Pricing
 * values are stringified, with absent prices becoming "0". `top_provider` and
 * `per_request_limits` are added only when the SDK model has them, so they are
 * never present as explicit `undefined`.
 */
export function sdkModelToOpenRouterModel(model: SDKModel): OpenRouterModel {
  const { architecture, pricing, topProvider: sdkTop, perRequestLimits: sdkLimits } = model;

  const result: OpenRouterModel = {
    id: model.id,
    name: model.name,
    architecture: {
      input_modalities: architecture.inputModalities ?? [],
      output_modalities: architecture.outputModalities ?? [],
    },
    context_length: model.contextLength ?? 0,
    pricing: {
      prompt: String(pricing.prompt ?? 0),
      completion: String(pricing.completion ?? 0),
      input_cache_read: String(pricing.inputCacheRead ?? 0),
      input_cache_write: String(pricing.inputCacheWrite ?? 0),
    },
    supported_parameters: model.supportedParameters,
  };

  if (sdkTop) {
    result.top_provider = {
      context_length: sdkTop.contextLength ?? 0,
      max_completion_tokens: sdkTop.maxCompletionTokens ?? 0,
    };
  }
  if (sdkLimits) {
    result.per_request_limits = {
      completion_tokens: sdkLimits.completionTokens ?? 0,
    };
  }

  return result;
}
54
+
55
/**
 * Return the model in OpenRouterModel (snake_case) form.
 * An object with a `contextLength` key is treated as an SDK model and
 * converted; anything else is returned unchanged.
 */
function normalizeModel(model: OpenRouterModel | SDKModel): OpenRouterModel {
  if ('contextLength' in model) {
    return sdkModelToOpenRouterModel(model as SDKModel);
  }
  return model as OpenRouterModel;
}
63
+
64
/**
 * Outcome of validateModel, discriminated on `valid`:
 * - invalid: why the model is skipped, plus its id (or 'unknown');
 * - valid: the model together with its resolved context window.
 */
type ValidationResult =
  | { valid: false; reason: string; modelId: string }
  | { valid: true; model: OpenRouterModel; contextWindow: number };
+
71
/**
 * Validate a model. Returns either the model with its resolved context window,
 * or the reason it must be skipped.
 *
 * Checks, in order:
 * 1. `id` is present (otherwise reported with `modelId: 'unknown'`);
 * 2. prompt and completion pricing are present (truthy, so the string '0' passes);
 * 3. a non-zero context window exists on the top provider or on the model;
 * 4. when output modalities are listed, they include 'text'.
 */
function validateModel(model: OpenRouterModel): ValidationResult {
  if (!model.id) {
    return { valid: false, reason: 'missing id', modelId: 'unknown' };
  }

  if (!model.pricing?.prompt) {
    return { valid: false, reason: 'missing prompt pricing', modelId: model.id };
  }
  if (!model.pricing?.completion) {
    return { valid: false, reason: 'missing completion pricing', modelId: model.id };
  }

  // Prefer the top provider's context length and fall back to the model-level
  // value. `||` rather than `??` is deliberate: sdkModelToOpenRouterModel
  // fills a missing SDK value with 0. With `??`, that placeholder 0 would shadow
  // a valid model-level context_length and the model would be wrongly skipped.
  const contextWindow = model.top_provider?.context_length || model.context_length;
  if (!contextWindow) {
    return { valid: false, reason: 'missing context window', modelId: model.id };
  }

  // Reject only when output modalities are listed and exclude text.
  const outputModalities = model.architecture?.output_modalities;
  if (outputModalities && !outputModalities.includes('text')) {
    return { valid: false, reason: 'non-text output modalities', modelId: model.id };
  }

  return { valid: true, model, contextWindow };
}
103
+
104
/**
 * Build a PiModelConfig from a model that has passed validateModel.
 *
 * - `reasoning` is true when `supported_parameters` contains 'reasoning' or
 *   'include_reasoning'.
 * - `input` adds 'image' only when the input modalities list it.
 * - Each cost is the per-token price (OpenRouter's pricing unit) multiplied by
 *   COST_PER_MILLION. Missing cache prices count as 0.
 * - `maxTokens` is the first non-zero value among the top provider's max
 *   completion tokens and the per-request completion limit, else DEFAULT_MAX_TOKENS.
 *
 * @param model - a model already accepted by validateModel
 * @param contextWindow - the context window resolved during validation
 */
function buildPiConfig(model: OpenRouterModel, contextWindow: number): PiModelConfig {
  const supportedParams = model.supported_parameters ?? [];
  const hasReasoning =
    supportedParams.includes('reasoning') || supportedParams.includes('include_reasoning');
  const inputModalities = model.architecture?.input_modalities;
  const supportsImages = inputModalities?.includes('image') ?? false;

  return {
    id: model.id,
    name: model.name ?? model.id,
    reasoning: hasReasoning,
    input: supportsImages ? ['text', 'image'] : ['text'],
    cost: {
      input: Number(model.pricing.prompt) * COST_PER_MILLION,
      output: Number(model.pricing.completion) * COST_PER_MILLION,
      cacheRead: Number(model.pricing.input_cache_read ?? 0) * COST_PER_MILLION,
      cacheWrite: Number(model.pricing.input_cache_write ?? 0) * COST_PER_MILLION,
    },
    contextWindow,
    // `||` rather than `??`: sdkModelToOpenRouterModel turns an absent SDK
    // limit into 0. A 0 here would otherwise win and presumably leave the
    // model with no output budget.
    maxTokens:
      model.top_provider?.max_completion_tokens ||
      model.per_request_limits?.completion_tokens ||
      DEFAULT_MAX_TOKENS,
  };
}
132
+
133
/**
 * Map a list of OpenRouter models (raw or SDK form) to Pi model configs.
 *
 * Router aliases are skipped silently because they are added separately after
 * mapping. Models that fail validation are not mapped; each one is recorded in
 * `skippedDetails` with its reason, and `skipped` is that list's length.
 */
export function mapOpenRouterModels(models: OpenRouterModel[] | SDKModel[]): MapResult {
  const configs: PiModelConfig[] = [];
  const skippedDetails: SkipReason[] = [];

  for (const entry of models) {
    const normalized = normalizeModel(entry);
    if (ROUTER_ALIASES.includes(normalized.id)) continue;

    const outcome = validateModel(normalized);
    if (outcome.valid) {
      configs.push(buildPiConfig(normalized, outcome.contextWindow));
    } else {
      skippedDetails.push({ id: outcome.modelId, reason: outcome.reason });
    }
  }

  return { configs, skipped: skippedDetails.length, skippedDetails };
}
162
+
163
/**
 * Map a single OpenRouter model (raw or SDK form) to a Pi model config.
 * Returns null for router aliases, which are handled separately, and for
 * models that fail validation.
 */
export function mapOpenRouterModel(model: OpenRouterModel | SDKModel): PiModelConfig | null {
  const normalized = normalizeModel(model);
  if (ROUTER_ALIASES.includes(normalized.id)) return null;

  const outcome = validateModel(normalized);
  return outcome.valid ? buildPiConfig(normalized, outcome.contextWindow) : null;
}