@sparkleideas/providers 3.5.2-patch.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,423 @@
1
+ /**
2
+ * V3 Cohere Provider
3
+ *
4
+ * Supports the Command R+, Command R, Command, and Command Light models.
5
+ *
6
+ * @module @sparkleideas/providers/cohere-provider
7
+ */
8
+
9
+ import { BaseProvider, BaseProviderOptions } from './base-provider.js';
10
+ import {
11
+ LLMProvider,
12
+ LLMModel,
13
+ LLMRequest,
14
+ LLMResponse,
15
+ LLMStreamEvent,
16
+ ModelInfo,
17
+ ProviderCapabilities,
18
+ HealthCheckResult,
19
+ AuthenticationError,
20
+ RateLimitError,
21
+ LLMProviderError,
22
+ } from './types.js';
23
+
24
/**
 * Request body sent to Cohere's `/chat` endpoint by `CohereProvider.buildRequest`.
 */
interface CohereRequest {
  /** Cohere model identifier, e.g. `command-r-plus`. */
  model: string;
  /** The current user turn; earlier turns are carried in `chat_history`. */
  message: string;
  /** Prior conversation turns, oldest first. */
  chat_history?: { role: 'USER' | 'CHATBOT' | 'SYSTEM'; message: string }[];
  /** System prompt text. */
  preamble?: string;
  temperature?: number;
  /** Nucleus (top-p) sampling value. */
  p?: number;
  /** Top-k sampling value. */
  k?: number;
  max_tokens?: number;
  stop_sequences?: string[];
  /** When true, the response is streamed as newline-delimited JSON events. */
  stream?: boolean;
  /** Tool definitions the model may call. */
  tools?: {
    name: string;
    description: string;
    parameter_definitions: Record<string, unknown>;
  }[];
}
44
+
45
/**
 * Non-streaming response body from Cohere's `/chat` endpoint, limited to the
 * fields this provider declares.
 */
interface CohereResponse {
  response_id: string;
  /** Generated reply text. */
  text: string;
  generation_id: string;
  chat_history: { role: string; message: string }[];
  /** Cohere's completion status; the provider treats `COMPLETE` as a normal stop. */
  finish_reason: string;
  meta: {
    api_version: { version: string };
    /** Token counts used for usage and cost reporting. */
    billed_units: {
      input_tokens: number;
      output_tokens: number;
    };
  };
  /** Tool invocations requested by the model, if any. */
  tool_calls?: {
    name: string;
    parameters: unknown;
  }[];
}
63
+
64
/**
 * LLM provider backed by Cohere's v1 Chat API (`POST {baseUrl}/chat`).
 *
 * Translates generic `LLMRequest`s into Cohere's message / chat_history /
 * preamble shape and maps Cohere responses and stream events back to the
 * generic response types, including token usage and cost estimates derived
 * from the static pricing table below.
 */
export class CohereProvider extends BaseProvider {
  readonly name: LLMProvider = 'cohere';
  readonly capabilities: ProviderCapabilities = {
    supportedModels: [
      'command-r-plus',
      'command-r',
      'command-light',
      'command',
    ],
    maxContextLength: {
      'command-r-plus': 128000,
      'command-r': 128000,
      'command-light': 4096,
      'command': 4096,
    },
    maxOutputTokens: {
      'command-r-plus': 4096,
      'command-r': 4096,
      'command-light': 4096,
      'command': 4096,
    },
    supportsStreaming: true,
    supportsToolCalling: true,
    supportsSystemMessages: true,
    supportsVision: false,
    supportsAudio: false,
    supportsFineTuning: true,
    supportsEmbeddings: true,
    supportsBatching: false,
    rateLimit: {
      requestsPerMinute: 1000,
      tokensPerMinute: 100000,
      concurrentRequests: 100,
    },
    pricing: {
      'command-r-plus': {
        promptCostPer1k: 0.003,
        completionCostPer1k: 0.015,
        currency: 'USD',
      },
      'command-r': {
        promptCostPer1k: 0.0005,
        completionCostPer1k: 0.0015,
        currency: 'USD',
      },
      'command-light': {
        promptCostPer1k: 0.0003,
        completionCostPer1k: 0.0006,
        currency: 'USD',
      },
      'command': {
        promptCostPer1k: 0.001,
        completionCostPer1k: 0.002,
        currency: 'USD',
      },
    },
  };

  /** API root; replaced by `config.apiUrl` (if set) during initialization. */
  private baseUrl: string = 'https://api.cohere.ai/v1';
  /** Auth and content-type headers, populated in `doInitialize`. */
  private headers: Record<string, string> = {};

  constructor(options: BaseProviderOptions) {
    super(options);
  }

  /**
   * Validates the API key and prepares the base URL and request headers.
   *
   * @throws AuthenticationError when `config.apiKey` is missing.
   */
  protected async doInitialize(): Promise<void> {
    if (!this.config.apiKey) {
      throw new AuthenticationError('Cohere API key is required', 'cohere');
    }

    this.baseUrl = this.config.apiUrl || 'https://api.cohere.ai/v1';
    this.headers = {
      Authorization: `Bearer ${this.config.apiKey}`,
      'Content-Type': 'application/json',
    };
  }

  /**
   * Sends a non-streaming chat request and converts the result.
   *
   * The timeout (`config.timeout`, default 60000 ms) covers both the request
   * and reading the response body. Any failure is passed through
   * `transformError` before being rethrown.
   */
  protected async doComplete(request: LLMRequest): Promise<LLMResponse> {
    const cohereRequest = this.buildRequest(request);

    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), this.config.timeout || 60000);

    try {
      const response = await fetch(`${this.baseUrl}/chat`, {
        method: 'POST',
        headers: this.headers,
        body: JSON.stringify(cohereRequest),
        signal: controller.signal,
      });

      if (!response.ok) {
        await this.handleErrorResponse(response);
      }

      const data = (await response.json()) as CohereResponse;
      return this.transformResponse(data, request);
    } catch (error) {
      throw this.transformError(error);
    } finally {
      // Cleared only after the body has been read, so a stalled body is aborted too.
      clearTimeout(timeout);
    }
  }

  /**
   * Streams a chat completion from Cohere's newline-delimited JSON event stream.
   *
   * Yields a `content` event per `text-generation` event and a `done` event
   * (usage + cost) for `stream-end`. The whole stream is bounded by twice the
   * configured timeout. If the consumer stops iterating early, the response
   * body is cancelled so the connection is released.
   */
  protected async *doStreamComplete(request: LLMRequest): AsyncIterable<LLMStreamEvent> {
    const cohereRequest = this.buildRequest(request, true);
    const model = (request.model || this.config.model) as LLMModel;

    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), (this.config.timeout || 60000) * 2);

    try {
      const response = await fetch(`${this.baseUrl}/chat`, {
        method: 'POST',
        headers: this.headers,
        body: JSON.stringify(cohereRequest),
        signal: controller.signal,
      });

      if (!response.ok) {
        await this.handleErrorResponse(response);
      }
      if (!response.body) {
        throw new Error('Cohere streaming response has no body');
      }

      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = '';

      try {
        while (true) {
          const { done, value } = await reader.read();
          if (done) break;

          buffer += decoder.decode(value, { stream: true });
          const lines = buffer.split('\n');
          // The last element is an incomplete line (or ''); carry it to the next chunk.
          buffer = lines.pop() ?? '';

          for (const line of lines) {
            const event = this.parseStreamLine(line, model);
            if (event) yield event;
          }
        }

        // Flush the decoder and handle a final line that had no trailing newline.
        buffer += decoder.decode();
        const last = this.parseStreamLine(buffer, model);
        if (last) yield last;
      } finally {
        try {
          await reader.cancel();
        } catch {
          // Stream already errored or was aborted; nothing further to release.
        }
      }
    } catch (error) {
      throw this.transformError(error);
    } finally {
      clearTimeout(timeout);
    }
  }

  /** Returns a copy of the statically supported model list. */
  async listModels(): Promise<LLMModel[]> {
    // Copy so callers cannot mutate the shared capabilities table.
    return [...this.capabilities.supportedModels];
  }

  /**
   * Describes a model using the static capability tables.
   *
   * Unknown models get a generic description and 4096-token limits.
   */
  async getModelInfo(model: LLMModel): Promise<ModelInfo> {
    const descriptions: Record<string, string> = {
      'command-r-plus': 'Most capable Cohere model with 128K context',
      'command-r': 'Balanced Cohere model with 128K context',
      'command-light': 'Fast and efficient Cohere model',
      'command': 'Standard Cohere model',
    };

    return {
      model,
      name: model,
      description: descriptions[model] || 'Cohere language model',
      contextLength: this.capabilities.maxContextLength[model] || 4096,
      maxOutputTokens: this.capabilities.maxOutputTokens[model] || 4096,
      supportedFeatures: ['chat', 'completion', 'tool_calling', 'rag'],
      pricing: this.capabilities.pricing[model],
    };
  }

  /**
   * Checks the API key via `POST {baseUrl}/check-api-key`.
   *
   * Never throws: failures, including a timeout after `config.timeout`
   * (default 60000 ms), are reported as `healthy: false` with an error string.
   */
  protected async doHealthCheck(): Promise<HealthCheckResult> {
    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), this.config.timeout || 60000);

    try {
      const response = await fetch(`${this.baseUrl}/check-api-key`, {
        method: 'POST',
        headers: this.headers,
        signal: controller.signal,
      });

      return {
        healthy: response.ok,
        timestamp: new Date(),
        ...(response.ok ? {} : { error: `HTTP ${response.status}` }),
      };
    } catch (error) {
      return {
        healthy: false,
        error: error instanceof Error ? error.message : 'Unknown error',
        timestamp: new Date(),
      };
    } finally {
      clearTimeout(timeout);
    }
  }

  /**
   * Converts a generic request into Cohere's chat request shape.
   *
   * - The last `user` message becomes `message`.
   * - All `system` messages are joined (blank-line separated) into `preamble`.
   * - Every other message becomes `chat_history` in original order:
   *   `assistant` maps to CHATBOT and all remaining roles map to USER.
   *
   * Request-level sampling options take precedence over provider config.
   */
  private buildRequest(request: LLMRequest, stream = false): CohereRequest {
    const lastUserMessage = [...request.messages].reverse().find((m) => m.role === 'user');
    const systemMessages = request.messages.filter((m) => m.role === 'system');

    const chatHistory = request.messages
      .filter((m) => m !== lastUserMessage && m.role !== 'system')
      .map((msg) => ({
        role: msg.role === 'assistant' ? ('CHATBOT' as const) : ('USER' as const),
        message: this.toText(msg.content),
      }));

    const cohereRequest: CohereRequest = {
      model: request.model || this.config.model,
      message: lastUserMessage ? this.toText(lastUserMessage.content) : '',
      stream,
    };

    if (chatHistory.length > 0) {
      cohereRequest.chat_history = chatHistory;
    }

    if (systemMessages.length > 0) {
      // Previously only the first system message was kept; later ones were silently dropped.
      cohereRequest.preamble = systemMessages.map((m) => this.toText(m.content)).join('\n\n');
    }

    if (request.temperature !== undefined || this.config.temperature !== undefined) {
      cohereRequest.temperature = request.temperature ?? this.config.temperature;
    }
    if (request.topP !== undefined || this.config.topP !== undefined) {
      cohereRequest.p = request.topP ?? this.config.topP;
    }
    if (request.topK !== undefined || this.config.topK !== undefined) {
      cohereRequest.k = request.topK ?? this.config.topK;
    }
    // `||` is intentional: a max token count of 0 is treated as "unset".
    if (request.maxTokens || this.config.maxTokens) {
      cohereRequest.max_tokens = request.maxTokens || this.config.maxTokens;
    }
    if (request.stopSequences || this.config.stopSequences) {
      cohereRequest.stop_sequences = request.stopSequences || this.config.stopSequences;
    }

    if (request.tools) {
      cohereRequest.tools = request.tools.map((tool) => ({
        name: tool.function.name,
        description: tool.function.description,
        // NOTE(review): this forwards the JSON-schema `properties` object as-is;
        // confirm it matches Cohere's expected parameter_definitions format.
        parameter_definitions: tool.function.parameters.properties as Record<string, unknown>,
      }));
    }

    return cohereRequest;
  }

  /** Renders message content as text; non-string content is JSON-encoded. */
  private toText(content: unknown): string {
    return typeof content === 'string' ? content : JSON.stringify(content);
  }

  /**
   * Estimates cost from the static per-1K-token pricing table.
   *
   * Models missing from `capabilities.pricing` are costed at 0 instead of
   * throwing, so usage is still reported for unlisted model names.
   */
  private calculateCost(
    model: LLMModel,
    inputTokens: number,
    outputTokens: number,
  ): { promptCost: number; completionCost: number; totalCost: number; currency: 'USD' } {
    const pricing = this.capabilities.pricing[model];
    const promptCost = pricing ? (inputTokens / 1000) * pricing.promptCostPer1k : 0;
    const completionCost = pricing ? (outputTokens / 1000) * pricing.completionCostPer1k : 0;

    return {
      promptCost,
      completionCost,
      totalCost: promptCost + completionCost,
      currency: 'USD',
    };
  }

  /**
   * Converts one line of Cohere's newline-delimited stream into a stream event.
   *
   * Returns `undefined` for blank lines, lines that are not JSON objects, and
   * event types this provider does not surface.
   */
  private parseStreamLine(line: string, model: LLMModel): LLMStreamEvent | undefined {
    if (!line.trim()) return undefined;

    let parsed: unknown;
    try {
      parsed = JSON.parse(line);
    } catch {
      // Deliberately best-effort: one unparseable line must not abort the stream.
      return undefined;
    }
    if (parsed === null || typeof parsed !== 'object') return undefined;

    const event = parsed as {
      event_type?: string;
      text?: string;
      response?: {
        meta?: { billed_units?: { input_tokens?: number; output_tokens?: number } };
      };
    };

    if (event.event_type === 'text-generation' && event.text) {
      return {
        type: 'content',
        delta: { content: event.text },
      };
    }

    if (event.event_type === 'stream-end') {
      const billed = event.response?.meta?.billed_units;
      const inputTokens = billed?.input_tokens ?? 0;
      const outputTokens = billed?.output_tokens ?? 0;

      return {
        type: 'done',
        usage: {
          promptTokens: inputTokens,
          completionTokens: outputTokens,
          totalTokens: inputTokens + outputTokens,
        },
        cost: this.calculateCost(model, inputTokens, outputTokens),
      };
    }

    return undefined;
  }

  /**
   * Maps a non-streaming Cohere response into the generic `LLMResponse`.
   *
   * Missing billing metadata is reported as 0 tokens rather than throwing.
   */
  private transformResponse(data: CohereResponse, request: LLMRequest): LLMResponse {
    const model = (request.model || this.config.model) as LLMModel;

    const inputTokens = data.meta?.billed_units?.input_tokens ?? 0;
    const outputTokens = data.meta?.billed_units?.output_tokens ?? 0;

    const timestamp = Date.now();
    const toolCalls = data.tool_calls?.map((tc, index) => ({
      // Index suffix keeps IDs unique when several calls arrive in one response.
      id: `tool_${timestamp}_${index}`,
      type: 'function' as const,
      function: {
        name: tc.name,
        arguments: JSON.stringify(tc.parameters),
      },
    }));

    return {
      id: data.response_id,
      model,
      provider: 'cohere',
      content: data.text,
      toolCalls: toolCalls?.length ? toolCalls : undefined,
      usage: {
        promptTokens: inputTokens,
        completionTokens: outputTokens,
        totalTokens: inputTokens + outputTokens,
      },
      cost: this.calculateCost(model, inputTokens, outputTokens),
      // NOTE(review): every finish_reason other than COMPLETE (including error
      // reasons) is reported as 'length'; confirm this mapping is intended.
      finishReason: data.finish_reason === 'COMPLETE' ? 'stop' : 'length',
    };
  }

  /**
   * Converts a non-2xx response into a typed provider error and throws it.
   *
   * 401 -> AuthenticationError, 429 -> RateLimitError, anything else ->
   * LLMProviderError (marked retryable for 5xx statuses).
   */
  private async handleErrorResponse(response: Response): Promise<never> {
    const errorText = await response.text();
    let errorData: { message?: string } = { message: errorText };

    try {
      const parsed: unknown = JSON.parse(errorText);
      // Only adopt parsed JSON objects; `null` or primitives would break `.message`.
      if (parsed !== null && typeof parsed === 'object') {
        errorData = parsed as { message?: string };
      }
    } catch {
      // Body is not JSON; the raw text is kept as the message.
    }

    const message =
      typeof errorData.message === 'string' && errorData.message ? errorData.message : 'Unknown error';

    switch (response.status) {
      case 401:
        throw new AuthenticationError(message, 'cohere', errorData);
      case 429:
        throw new RateLimitError(message, 'cohere', undefined, errorData);
      default:
        throw new LLMProviderError(
          message,
          `COHERE_${response.status}`,
          'cohere',
          response.status,
          response.status >= 500,
          errorData
        );
    }
  }
}