voyageai-cli 1.22.1 → 1.23.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,252 @@
1
+ 'use strict';
2
+
3
+ /**
4
+ * Chat Orchestrator
5
+ *
6
+ * Coordinates the retrieval pipeline (embed → search → rerank)
7
+ * with LLM generation and history management.
8
+ */
9
+
10
+ const { generateEmbeddings, apiRequest } = require('./api');
11
+
12
/**
 * Derive a display label for a retrieved document.
 *
 * Checks identifying metadata fields in priority order (title, name, subject,
 * heading, filename) and uses the first non-empty string value, suffixed with
 * " (<year>)" when metadata.year is truthy. Without such a field, falls back to
 * doc.source, then metadata.source, then the stringified _id, then 'unknown'.
 *
 * @param {object} doc - A document returned by the vector search.
 * @returns {string}
 */
function resolveSourceLabel(doc) {
  const meta = doc.metadata || {};
  const labelKey = ['title', 'name', 'subject', 'heading', 'filename']
    .find((key) => typeof meta[key] === 'string' && meta[key]);

  if (labelKey === undefined) {
    // No identifying metadata: use the raw source path / _id.
    return doc.source || meta.source || doc._id?.toString() || 'unknown';
  }

  const label = meta[labelKey];
  // Year suffix helps disambiguate e.g. movies/articles sharing a title.
  return meta.year ? `${label} (${meta.year})` : label;
}
34
+ const { getMongoCollection } = require('./mongo');
35
+ const { buildMessages } = require('./prompt');
36
+ const { getDefaultModel, DEFAULT_RERANK_MODEL } = require('./catalog');
37
+ const { loadProject } = require('./project');
38
+
39
/**
 * Perform retrieval: embed query → vector search → optional rerank.
 *
 * Settings resolve as opts → project config (loadProject) → built-in default.
 * On success the MongoDB client is returned open so the caller can close it;
 * if any step after connecting throws, the client is closed here before the
 * error propagates (the caller would otherwise never receive it).
 *
 * @param {object} params
 * @param {string} params.query - User's question
 * @param {string} params.db - MongoDB database name
 * @param {string} params.collection - Collection with embedded docs
 * @param {object} [params.opts] - Additional options
 * @param {string} [params.opts.model] - Embedding model
 * @param {string} [params.opts.index] - Vector search index name (default 'vector_index')
 * @param {string} [params.opts.field] - Embedding field name (default 'embedding')
 * @param {number} [params.opts.dimensions] - Embedding dimensions
 * @param {number} [params.opts.maxDocs] - Max documents to return (default 5)
 * @param {boolean} [params.opts.rerank] - Whether to rerank (default true)
 * @param {string} [params.opts.rerankModel] - Rerank model (default DEFAULT_RERANK_MODEL)
 * @param {string} [params.opts.textField] - Document text field name (default 'text')
 * @param {string|object} [params.opts.filter] - Vector search pre-filter, as a JSON string or an object
 * @returns {Promise<{docs: Array, client: MongoClient, retrievalTimeMs: number, tokens: {embed: number, rerank: number}}>}
 * @throws {Error} 'Invalid --filter JSON.' when opts.filter is a string that is not valid JSON
 */
async function retrieve({ query, db, collection, opts = {} }) {
  const { config: proj } = loadProject();
  const model = opts.model || proj.model || getDefaultModel();
  const index = opts.index || proj.index || 'vector_index';
  const field = opts.field || proj.field || 'embedding';
  const dimensions = opts.dimensions || proj.dimensions;
  const maxDocs = opts.maxDocs || 5;
  const doRerank = opts.rerank !== false;
  const textField = opts.textField || 'text';
  // Over-fetch (up to 20) so the reranker has extra candidates to choose from,
  // but never ask the vector search for fewer docs than the caller wants back.
  const limit = Math.max(maxDocs, Math.min(maxDocs * 4, 20));

  // Parse the pre-filter first so bad input fails before we spend an
  // embedding call or open a database connection.
  let filter;
  if (opts.filter) {
    try {
      filter = typeof opts.filter === 'string' ? JSON.parse(opts.filter) : opts.filter;
    } catch {
      throw new Error('Invalid --filter JSON.');
    }
  }

  const start = Date.now();

  // Step 1: Embed query
  const embedOpts = { model, inputType: 'query' };
  if (dimensions) embedOpts.dimensions = dimensions;
  const embedResult = await generateEmbeddings([query], embedOpts);
  const queryVector = embedResult.data[0].embedding;
  const embedTokens = embedResult.usage?.total_tokens || 0;

  // Step 2: Vector search
  const { client, collection: coll } = await getMongoCollection(db, collection);

  try {
    const vectorSearchStage = {
      index,
      path: field,
      queryVector,
      numCandidates: Math.min(limit * 15, 10000),
      limit,
    };
    if (filter !== undefined) vectorSearchStage.filter = filter;

    const pipeline = [
      { $vectorSearch: vectorSearchStage },
      { $addFields: { _vsScore: { $meta: 'vectorSearchScore' } } },
    ];

    const searchResults = await coll.aggregate(pipeline).toArray();

    if (searchResults.length === 0) {
      return { docs: [], client, retrievalTimeMs: Date.now() - start, tokens: { embed: embedTokens, rerank: 0 } };
    }

    // Step 3: Rerank (optional; skipped when there is only one candidate)
    let finalDocs;
    let rerankTokens = 0;

    if (doRerank && searchResults.length > 1) {
      const rerankModel = opts.rerankModel || DEFAULT_RERANK_MODEL;
      const documents = searchResults.map((doc) => {
        const txt = doc[textField];
        if (txt) return typeof txt === 'string' ? txt : JSON.stringify(txt);
        // No text field: send the document itself, minus the embedding vector
        // and search score, which would only burn rerank tokens.
        const rest = { ...doc };
        delete rest[field];
        delete rest._vsScore;
        return JSON.stringify(rest);
      });

      const rerankResult = await apiRequest('/rerank', {
        query,
        documents,
        model: rerankModel,
        top_k: maxDocs,
      });
      rerankTokens = rerankResult.usage?.total_tokens || 0;

      finalDocs = (rerankResult.data || [])
        // Ignore any result whose index does not map back to a candidate.
        .filter((item) => searchResults[item.index] !== undefined)
        .map((item) => {
          const doc = searchResults[item.index];
          return {
            text: doc[textField] || '',
            source: resolveSourceLabel(doc),
            score: item.relevance_score,
            vectorScore: doc._vsScore,
            metadata: doc.metadata || {},
          };
        });
    } else {
      finalDocs = searchResults.slice(0, maxDocs).map((doc) => ({
        text: doc[textField] || '',
        source: resolveSourceLabel(doc),
        score: doc._vsScore,
        vectorScore: doc._vsScore,
        metadata: doc.metadata || {},
      }));
    }

    return {
      docs: finalDocs,
      client,
      retrievalTimeMs: Date.now() - start,
      tokens: { embed: embedTokens, rerank: rerankTokens },
    };
  } catch (err) {
    // The caller never receives the client on failure, so release it here.
    if (client) {
      try { await client.close(); } catch { /* keep the original error */ }
    }
    throw err;
  }
}
155
+
156
/**
 * Execute a single chat turn: retrieve context → build prompt → generate response.
 *
 * The MongoDB client opened by retrieve() is closed as soon as retrieval
 * finishes. Only the already-fetched docs are used afterwards, so no
 * connection is held while the generator is suspended or the LLM is
 * generating. This also means nothing leaks if prompt building throws or the
 * consumer stops iterating early.
 *
 * @param {object} params
 * @param {string} params.query - User's question
 * @param {string} params.db - MongoDB database name
 * @param {string} params.collection - Collection name
 * @param {object} params.llm - LLM provider instance (uses llm.chat, llm.name, llm.model)
 * @param {import('./history').ChatHistory} params.history - Chat history
 * @param {object} [params.opts] - Additional options
 * @param {string} [params.opts.systemPrompt] - Custom system prompt
 * @param {number} [params.opts.maxDocs] - Max context docs
 * @param {boolean} [params.opts.rerank] - Whether to rerank
 * @param {boolean} [params.opts.stream] - Whether to stream (default true)
 * @param {string} [params.opts.textField] - Document text field
 * @param {string} [params.opts.filter] - Vector search pre-filter
 * @returns {AsyncGenerator<{type: string, data: any}>}
 *   Yields: { type: 'retrieval', data: { docs, timeMs, tokens } }
 *           { type: 'chunk', data: string }
 *           { type: 'done', data: { fullResponse, sources, metadata } }
 *   The user and assistant turns are added to `history` only after generation completes.
 */
async function* chatTurn({ query, db, collection, llm, history, opts = {} }) {
  // 1. Retrieve context
  const { docs, client, retrievalTimeMs, tokens } = await retrieve({
    query, db, collection,
    opts: {
      maxDocs: opts.maxDocs,
      rerank: opts.rerank,
      textField: opts.textField,
      filter: opts.filter,
    },
  });

  // Release the connection before the first yield; nothing below needs it.
  if (client) {
    try { await client.close(); } catch { /* best-effort close */ }
  }

  yield { type: 'retrieval', data: { docs, timeMs: retrievalTimeMs, tokens } };

  // 2. Build messages
  const messages = buildMessages({
    query,
    contextDocs: docs,
    history: history.getMessagesWithBudget(8000),
    systemPrompt: opts.systemPrompt,
  });

  // 3. Generate response (streaming unless opts.stream === false)
  let fullResponse = '';
  const stream = opts.stream !== false;
  // Time only the LLM call (this still includes time the consumer spends
  // handling each yielded chunk).
  const genStart = Date.now();
  for await (const chunk of llm.chat(messages, { stream })) {
    fullResponse += chunk;
    yield { type: 'chunk', data: chunk };
  }
  const generationTimeMs = Date.now() - genStart;

  // 4. Store turns in history
  await history.addTurn({ role: 'user', content: query });
  await history.addTurn({
    role: 'assistant',
    content: fullResponse,
    context: docs,
    metadata: {
      llmProvider: llm.name,
      llmModel: llm.model,
      retrievalTimeMs,
      generationTimeMs,
      contextDocsUsed: docs.length,
    },
  });

  yield {
    type: 'done',
    data: {
      fullResponse,
      sources: docs.map((d) => ({ source: d.source, score: d.score })),
      metadata: {
        retrievalTimeMs,
        generationTimeMs,
        tokens,
        contextDocsUsed: docs.length,
      },
    },
  };
}
248
+
249
// Public API of the chat orchestrator: standalone retrieval, and the full
// retrieve → prompt → generate chat turn (an async generator).
module.exports = {
  retrieve,
  chatTurn,
};
@@ -103,7 +103,7 @@ function render(template, context = {}) {
103
103
  result = processUnlessBlocks(result, context);
104
104
 
105
105
  // Process simple variable substitutions {{variable}} and {{variable.nested}}
106
- result = result.replace(/\{\{([a-zA-Z_][\w.]*)\}\}/g, (match, varPath) => {
106
+ result = result.replace(/\{\{(@?[a-zA-Z_][\w.]*)\}\}/g, (match, varPath) => {
107
107
  const value = getPath(context, varPath);
108
108
  if (value === undefined || value === null) return '';
109
109
  if (typeof value === 'object') return JSON.stringify(value);
@@ -181,7 +181,7 @@ function processIfBlocks(template, context) {
181
181
 
182
182
  // Match if-else blocks that don't contain nested {{#if (innermost first)
183
183
  // This regex ensures we don't have another {{#if inside the captured groups
184
- const ifElseRegex = /\{\{#if\s+(\w+(?:\.\w+)*)\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{else\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{\/if\}\}/;
184
+ const ifElseRegex = /\{\{#if\s+(@?\w+(?:\.\w+)*)\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{else\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{\/if\}\}/;
185
185
 
186
186
  let match = result.match(ifElseRegex);
187
187
  if (match) {
@@ -194,7 +194,7 @@ function processIfBlocks(template, context) {
194
194
  }
195
195
 
196
196
  // Match simple if blocks that don't contain nested {{#if
197
- const ifRegex = /\{\{#if\s+(\w+(?:\.\w+)*)\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{\/if\}\}/;
197
+ const ifRegex = /\{\{#if\s+(@?\w+(?:\.\w+)*)\}\}((?:(?!\{\{#if)[\s\S])*?)\{\{\/if\}\}/;
198
198
 
199
199
  match = result.match(ifRegex);
200
200
  if (match) {
@@ -213,7 +213,7 @@ function processIfBlocks(template, context) {
213
213
  * Process {{#unless condition}}...{{/unless}} blocks.
214
214
  */
215
215
  function processUnlessBlocks(template, context) {
216
- const unlessRegex = /\{\{#unless\s+(\w+(?:\.\w+)*)\}\}([\s\S]*?)\{\{\/unless\}\}/g;
216
+ const unlessRegex = /\{\{#unless\s+(@?\w+(?:\.\w+)*)\}\}([\s\S]*?)\{\{\/unless\}\}/g;
217
217
 
218
218
  return template.replace(unlessRegex, (match, varPath, blockContent) => {
219
219
  const value = getPath(context, varPath);
@@ -322,6 +322,7 @@ function buildContext(project, options = {}) {
322
322
  // Metadata
323
323
  generatedAt: new Date().toISOString(),
324
324
  vaiVersion: getCliVersion(),
325
+ vaiVersion: require('../../package.json').version,
325
326
  };
326
327
 
327
328
  return context;
package/src/lib/config.js CHANGED
@@ -14,10 +14,14 @@ const KEY_MAP = {
14
14
  'default-model': 'defaultModel',
15
15
  'default-dimensions': 'defaultDimensions',
16
16
  'base-url': 'baseUrl',
17
+ 'llm-provider': 'llmProvider',
18
+ 'llm-api-key': 'llmApiKey',
19
+ 'llm-model': 'llmModel',
20
+ 'llm-base-url': 'llmBaseUrl',
17
21
  };
18
22
 
19
23
  // Keys whose values should be masked in output
20
- const SECRET_KEYS = new Set(['apiKey', 'mongodbUri']);
24
+ const SECRET_KEYS = new Set(['apiKey', 'mongodbUri', 'llmApiKey']);
21
25
 
22
26
  /**
23
27
  * Load config from disk. Returns {} if file doesn't exist.
@@ -0,0 +1,352 @@
1
+ 'use strict';
2
+
3
+ /**
4
+ * Token and cost estimation utilities.
5
+ *
6
+ * Shared by all commands that support --estimate.
7
+ * Uses the model catalog for pricing and a ~4 chars/token heuristic.
8
+ */
9
+
10
+ const { MODEL_CATALOG } = require('./catalog');
11
+
12
/**
 * Approximate the token count of a string with a ~4 characters-per-token
 * heuristic, rounded up. Falsy input (undefined, null, '') counts as 0.
 * @param {string} text
 * @returns {number}
 */
function estimateTokens(text) {
  return text ? Math.ceil(text.length / 4) : 0;
}
21
+
22
/**
 * Sum the per-text token estimates (via estimateTokens) over an array.
 * An empty array yields 0.
 * @param {string[]} texts
 * @returns {number}
 */
function estimateTokensForTexts(texts) {
  const perText = texts.map((text) => estimateTokens(text));
  return perText.reduce((total, count) => total + count, 0);
}
30
+
31
/**
 * Look up a model's price per million tokens in MODEL_CATALOG.
 * @param {string} modelName
 * @returns {number|null} the first matching entry's pricePerMToken, or null
 *   when the model is not in the catalog or has no price.
 */
function getModelPrice(modelName) {
  for (const entry of MODEL_CATALOG) {
    if (entry.name === modelName) {
      return entry.pricePerMToken ?? null;
    }
  }
  return null;
}
40
+
41
/**
 * Turn a token count into an estimated cost using the catalog price.
 * @param {number} tokens - estimated token count
 * @param {string} modelName - model name from catalog
 * @returns {{ tokens: number, cost: number|null, model: string, pricePerMToken: number|null }}
 *   cost is null when the model has no known price.
 */
function estimateCost(tokens, modelName) {
  const pricePerMToken = getModelPrice(modelName);
  let cost = null;
  if (pricePerMToken !== null && pricePerMToken !== undefined) {
    cost = pricePerMToken * (tokens / 1_000_000);
  }
  return { tokens, cost, model: modelName, pricePerMToken };
}
52
+
53
/**
 * Estimate cost for a chat turn (embedding query + reranking + LLM generation).
 *
 * All token counts are heuristics: the query uses estimateTokens(), the
 * context is contextDocs * avgDocTokens, history is ~150 tokens per turn, the
 * system prompt is ~100 tokens and the response ~300 tokens.
 *
 * @param {object} [params]
 * @param {string} [params.query] - user's question text
 * @param {number} [params.contextDocs=5] - number of context docs
 * @param {number} [params.avgDocTokens=200] - average tokens per context doc
 * @param {string} [params.embeddingModel='voyage-4-large'] - Voyage embedding model
 * @param {string|null} [params.rerankModel='rerank-2.5'] - Voyage rerank model; falsy disables the rerank stage
 * @param {string} [params.llmProvider='anthropic'] - 'anthropic' | 'openai' | 'ollama'
 * @param {string} [params.llmModel] - specific LLM model name
 * @param {number} [params.historyTurns=0] - number of conversation turns in context
 * @returns {{ embed: object, rerank: object|null, llm: object, totalTokens: number, totalCost: number }}
 *   per-stage estimates; rerank is null when reranking is disabled.
 */
function estimateChatCost({
  query,
  contextDocs = 5,
  avgDocTokens = 200,
  embeddingModel = 'voyage-4-large',
  rerankModel = 'rerank-2.5',
  llmProvider = 'anthropic',
  llmModel,
  historyTurns = 0,
} = {}) {
  const queryTokens = estimateTokens(query);
  const contextTokens = contextDocs * avgDocTokens;
  const historyTokens = historyTurns * 150; // ~150 tokens per turn pair
  const systemPromptTokens = 100; // rough estimate

  // Stage 1: Embedding the query
  const embedCost = estimateCost(queryTokens, embeddingModel);

  // Stage 2: Reranking candidates. With no rerank model the stage does not
  // run, so it contributes no tokens to the total.
  const rerankTokens = rerankModel ? queryTokens + contextTokens : 0;
  const rerankCost = rerankModel ? estimateCost(rerankTokens, rerankModel) : null;

  // Stage 3: LLM generation
  const llmInputTokens = systemPromptTokens + contextTokens + historyTokens + queryTokens;
  const llmOutputTokens = 300; // estimated response length
  const llmCost = estimateLLMCost(llmProvider, llmModel, llmInputTokens, llmOutputTokens);

  // Unknown prices (null) count as 0 here.
  const totalCost = (embedCost.cost || 0)
    + (rerankCost?.cost || 0)
    + (llmCost?.cost || 0);

  return {
    embed: embedCost,
    rerank: rerankCost,
    llm: llmCost,
    totalTokens: queryTokens + rerankTokens + llmInputTokens + llmOutputTokens,
    totalCost,
  };
}
105
+
106
/**
 * Approximate LLM pricing per 1M tokens (input/output), keyed by provider and
 * then model name. Built once at module load instead of on every call.
 */
const LLM_PRICING = Object.freeze({
  anthropic: Object.freeze({
    'claude-sonnet-4-5-20250929': { input: 3.0, output: 15.0 },
    'claude-opus-4-20250514': { input: 15.0, output: 75.0 },
    'claude-3-5-haiku-20241022': { input: 1.0, output: 5.0 },
  }),
  openai: Object.freeze({
    'gpt-4o': { input: 2.5, output: 10.0 },
    'gpt-4o-mini': { input: 0.15, output: 0.6 },
    'gpt-4-turbo': { input: 10.0, output: 30.0 },
    'o1': { input: 15.0, output: 60.0 },
    'o1-mini': { input: 3.0, output: 12.0 },
    'o3-mini': { input: 1.1, output: 4.4 },
  }),
  ollama: Object.freeze({}), // all free
});

/**
 * Rough LLM cost estimation (cloud providers only; ollama is always 0).
 * @param {string} provider
 * @param {string} model
 * @param {number} inputTokens
 * @param {number} outputTokens
 * @returns {{ inputTokens: number, outputTokens: number, cost: number|null, model: string, provider: string }}
 *   cost is null when the provider/model pair has no known pricing.
 */
function estimateLLMCost(provider, model, inputTokens, outputTokens) {
  // Own-property lookups only, so names like 'constructor' or 'toString'
  // can't resolve to inherited Object.prototype members (which gave NaN).
  const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);
  const providerPricing = hasOwn(LLM_PRICING, provider) ? LLM_PRICING[provider] : {};
  const modelPricing = hasOwn(providerPricing, model) ? providerPricing[model] : undefined;

  let cost = null;
  if (provider === 'ollama') {
    cost = 0;
  } else if (modelPricing) {
    cost = (inputTokens / 1_000_000) * modelPricing.input
      + (outputTokens / 1_000_000) * modelPricing.output;
  }

  return { inputTokens, outputTokens, cost, model: model || 'unknown', provider };
}
146
+
147
/**
 * Estimate cost across all comparable Voyage models.
 *
 * Rows are the catalog models of the selected model's type (default
 * 'embedding') that have a price. Legacy and unreleased models are excluded,
 * and so are domain-specific embedding models (bestFor mentions finance,
 * legal, code or context). The user's selected model is always kept when it
 * has a price, even if one of those exclusions would apply.
 *
 * @param {number} tokens - estimated token count
 * @param {string} selectedModel - the user's chosen model
 * @returns {Array<{ model: string, tokens: number, cost: number, pricePerMToken: number, selected: boolean, shortFor: string }>}
 *   sorted by price, highest first.
 */
function estimateCostComparison(tokens, selectedModel) {
  // Compare apples-to-apples: only models of the selected model's type.
  const selected = MODEL_CATALOG.find((m) => m.name === selectedModel);
  const type = selected?.type || 'embedding';

  // bestFor keywords that mark domain-specific embedding models.
  const domainKeywords = ['finance', 'legal', 'code', 'context'];
  const isDomainSpecific = (m) => {
    const bestFor = (m.bestFor || '').toLowerCase();
    return domainKeywords.some((keyword) => bestFor.includes(keyword));
  };

  const isComparable = (m) => {
    if (m.type !== type || m.pricePerMToken == null) return false;
    // Always keep the user's own model so it can be shown and chosen.
    if (m.name === selectedModel) return true;
    if (m.legacy || m.unreleased) return false;
    // Domain filtering applies to embeddings only (not rerank etc.).
    return type !== 'embedding' || !isDomainSpecific(m);
  };

  return MODEL_CATALOG
    .filter(isComparable)
    .map((m) => ({
      model: m.name,
      tokens,
      cost: (tokens / 1_000_000) * m.pricePerMToken,
      pricePerMToken: m.pricePerMToken,
      selected: m.name === selectedModel,
      shortFor: m.shortFor || m.bestFor || '',
    }))
    .sort((a, b) => b.pricePerMToken - a.pricePerMToken); // highest price first
}
181
+
182
+ /**
183
+ * Format a cost estimate for terminal display with model comparison.
184
+ * @param {object} estimate - from estimateCost()
185
+ * @returns {string}
186
+ */
187
+ function formatCostEstimate(estimate) {
188
+ const pc = require('picocolors');
189
+ const lines = [];
190
+
191
+ const comparison = estimateCostComparison(estimate.tokens, estimate.model);
192
+
193
+ lines.push(pc.bold(` Cost Estimate — ${estimate.tokens.toLocaleString()} tokens`));
194
+ lines.push('');
195
+
196
+ if (comparison.length > 1) {
197
+ // Table header
198
+ lines.push(` ${pc.dim(padRight('Model', 22))} ${pc.dim(padRight('Quality', 14))} ${pc.dim(padRight('Price/1M', 10))} ${pc.dim('Est. Cost')}`);
199
+ lines.push(` ${pc.dim('─'.repeat(60))}`);
200
+
201
+ for (const row of comparison) {
202
+ const costStr = row.cost < 0.001 ? '< $0.001' : `$${row.cost.toFixed(4)}`;
203
+ const marker = row.selected ? pc.green(' ← selected') : '';
204
+ const nameStr = row.selected ? pc.bold(row.model) : row.model;
205
+ lines.push(` ${padRight(nameStr, 22)} ${padRight(row.shortFor, 14)} $${padRight(row.pricePerMToken.toFixed(2), 9)} ${pc.cyan(costStr)}${marker}`);
206
+ }
207
+ } else {
208
+ // Single model fallback
209
+ lines.push(` Model: ${estimate.model}`);
210
+ if (estimate.cost != null) {
211
+ lines.push(` Cost: ${pc.cyan(`$${estimate.cost.toFixed(4)}`)}`);
212
+ } else {
213
+ lines.push(` Cost: unknown pricing`);
214
+ }
215
+ }
216
+
217
+ return lines.join('\n');
218
+ }
219
+
220
/**
 * Right-pad `str` with spaces to a visible width of `len` columns.
 * ANSI color sequences (ESC [ ... m) are ignored when measuring, so colored
 * cells line up with plain ones. Strings already at least `len` visible
 * characters wide are returned unchanged.
 */
function padRight(str, len) {
  const visibleLength = str.replace(/\x1b\[[0-9;]*m/g, '').length;
  const missing = len - visibleLength;
  return missing > 0 ? `${str}${' '.repeat(missing)}` : str;
}
226
+
227
+ /**
228
+ * Format a chat cost breakdown for terminal display.
229
+ * @param {object} breakdown - from estimateChatCost()
230
+ * @returns {string}
231
+ */
232
+ function formatChatCostBreakdown(breakdown) {
233
+ const pc = require('picocolors');
234
+ const lines = [];
235
+
236
+ lines.push(pc.bold(' Chat Cost Estimate (per turn)'));
237
+ lines.push(` ${pc.dim('─'.repeat(40))}`);
238
+
239
+ // Embedding
240
+ const embedPrice = breakdown.embed.cost != null
241
+ ? `$${breakdown.embed.cost.toFixed(6)}`
242
+ : '?';
243
+ lines.push(` ${pc.dim('Embed query:')} ${breakdown.embed.tokens.toLocaleString()} tokens ${pc.dim(embedPrice)}`);
244
+
245
+ // Reranking
246
+ if (breakdown.rerank) {
247
+ const rerankPrice = breakdown.rerank.cost != null
248
+ ? `$${breakdown.rerank.cost.toFixed(6)}`
249
+ : '?';
250
+ lines.push(` ${pc.dim('Rerank:')} ${breakdown.rerank.tokens.toLocaleString()} tokens ${pc.dim(rerankPrice)}`);
251
+ }
252
+
253
+ // LLM
254
+ const llmPrice = breakdown.llm.cost != null
255
+ ? `$${breakdown.llm.cost.toFixed(6)}`
256
+ : (breakdown.llm.model === 'ollama' ? 'free' : '?');
257
+ lines.push(` ${pc.dim('LLM input:')} ${breakdown.llm.inputTokens.toLocaleString()} tokens`);
258
+ lines.push(` ${pc.dim('LLM output:')} ~${breakdown.llm.outputTokens.toLocaleString()} tokens ${pc.dim(llmPrice)}`);
259
+
260
+ // Total
261
+ lines.push(` ${pc.dim('─'.repeat(40))}`);
262
+ const totalStr = breakdown.totalCost < 0.001
263
+ ? `< $0.001`
264
+ : `~$${breakdown.totalCost.toFixed(4)}`;
265
+ lines.push(` ${pc.bold('Total:')} ~${breakdown.totalTokens.toLocaleString()} tokens ${pc.cyan(totalStr)}`);
266
+
267
+ return lines.join('\n');
268
+ }
269
+
270
+ /**
271
+ * Show cost estimate and let user confirm or switch models interactively.
272
+ * Returns the chosen model name, or null if cancelled.
273
+ *
274
+ * @param {number} tokens - estimated token count
275
+ * @param {string} selectedModel - current model
276
+ * @param {object} [opts]
277
+ * @param {boolean} [opts.json] - if true, skip interactive and return selected
278
+ * @param {boolean} [opts.nonInteractive] - if true, just display and return selected
279
+ * @returns {Promise<string|null>} chosen model name, or null if cancelled
280
+ */
281
+ async function confirmOrSwitchModel(tokens, selectedModel, opts = {}) {
282
+ const pc = require('picocolors');
283
+ const est = estimateCost(tokens, selectedModel);
284
+
285
+ // Display the comparison table
286
+ console.log('');
287
+ console.log(formatCostEstimate(est));
288
+ console.log('');
289
+
290
+ if (opts.json || opts.nonInteractive) {
291
+ return selectedModel;
292
+ }
293
+
294
+ // Build choices: proceed with current, switch to each alternative, cancel
295
+ const comparison = estimateCostComparison(tokens, selectedModel);
296
+ const p = require('@clack/prompts');
297
+
298
+ const options = [];
299
+
300
+ // Current model first
301
+ const currentRow = comparison.find(r => r.selected);
302
+ if (currentRow) {
303
+ const costStr = currentRow.cost < 0.001 ? '< $0.001' : `$${currentRow.cost.toFixed(4)}`;
304
+ options.push({
305
+ value: currentRow.model,
306
+ label: `Proceed with ${currentRow.model} (${costStr})`,
307
+ });
308
+ }
309
+
310
+ // Alternatives
311
+ for (const row of comparison) {
312
+ if (row.selected) continue;
313
+ const costStr = row.cost < 0.001 ? '< $0.001' : `$${row.cost.toFixed(4)}`;
314
+ options.push({
315
+ value: row.model,
316
+ label: `Switch to ${row.model} (${costStr})`,
317
+ hint: row.shortFor,
318
+ });
319
+ }
320
+
321
+ // Cancel
322
+ options.push({
323
+ value: '__cancel__',
324
+ label: pc.dim('Cancel'),
325
+ });
326
+
327
+ const choice = await p.select({
328
+ message: 'Choose a model',
329
+ options,
330
+ initialValue: selectedModel,
331
+ });
332
+
333
+ if (p.isCancel(choice) || choice === '__cancel__') {
334
+ p.cancel('Cancelled.');
335
+ return null;
336
+ }
337
+
338
+ return choice;
339
+ }
340
+
341
// Public API of the token/cost estimation utilities: token heuristics,
// catalog price lookup, per-model and per-chat-turn cost estimates, terminal
// formatters, and the interactive model confirm/switch prompt.
// padRight is not listed here, so it stays module-private.
module.exports = {
  estimateTokens,
  estimateTokensForTexts,
  getModelPrice,
  estimateCost,
  estimateCostComparison,
  estimateChatCost,
  estimateLLMCost,
  formatCostEstimate,
  formatChatCostBreakdown,
  confirmOrSwitchModel,
};