provider-kit 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,319 @@
1
+ import { ProviderError } from './provider-error-adapter.js';
2
/**
 * Cohere API adapter.
 *
 * Cohere uses its own request/response format:
 * - Endpoint: /v1/chat
 * - Header: Authorization: Bearer
 * - Request shape: { message, chat_history, model, ... }
 * - Response shape: { text, generation_id, ... }
 *
 * Reference: https://docs.cohere.com/reference/chat
 */
export class CohereAdapter {
  /**
   * @param {object} [config] - Adapter settings; every field is optional and
   *   falls back to the default shown in the constructor body.
   * @param {number} [config.timeout] - Request timeout in milliseconds.
   *   An explicit 0 disables the timeout; omitted/null uses 60000.
   * @param {object} [config.headers] - Extra headers merged into every
   *   request after the built-in ones, so they can override them.
   */
  constructor(config = {}) {
    this.id = config.id || 'cohere';
    this.name = config.name || 'Cohere';
    this.nameCn = config.nameCn || 'Cohere';
    this.baseUrl = config.baseUrl || 'https://api.cohere.ai/v1';
    this.apiKey = config.apiKey || null;
    this.defaultModel = config.defaultModel || 'command-r-plus';
    this.models = config.models || [
      'command-r-plus',
      'command-r',
      'command',
      'command-light',
      'command-nightly',
      'command-light-nightly'
    ];
    this.connected = false;
    this.description = config.description || 'Cohere 企业级 AI 模型';
    // `??` (not `||`) so that an explicit 0 survives; postJson() treats a
    // falsy timeout as "do not attach an abort signal".
    this.timeout = config.timeout ?? 60000;
    this.headers = config.headers || {};
  }

  /**
   * Mark the adapter as connected. No network request is made; the only
   * check is that an API key is available.
   *
   * @param {string} [apiKey] - Replaces the stored key when truthy.
   * @returns {Promise<boolean>} Always resolves to true on success.
   * @throws {ProviderError} When no API key is set.
   */
  async connect(apiKey) {
    if (apiKey) this.apiKey = apiKey;

    if (!this.apiKey) {
      throw new ProviderError('API Key required for Cohere');
    }

    this.connected = true;
    return true;
  }

  /**
   * Mark the adapter as disconnected.
   */
  disconnect() {
    this.connected = false;
  }

  /**
   * Convert OpenAI-style messages into Cohere chat parameters.
   *
   * - `system` messages are joined with blank lines into `preamble`.
   * - A `user` message in the final position becomes `message`; earlier
   *   `user` messages and all `assistant` messages go to `chat_history`.
   * - Messages with any other role are skipped.
   *
   * If the final entry is not a `user` message, `message` is an empty string.
   *
   * @param {Array<{role: string, content: *}>} messages
   * @returns {{preamble: (string|undefined), chat_history: (Array|undefined), message: *}}
   */
  convertMessages(messages) {
    const preamble = [];
    const chatHistory = [];
    let lastMessage = '';

    for (let i = 0; i < messages.length; i++) {
      const msg = messages[i];

      if (msg.role === 'system') {
        preamble.push(msg.content);
      } else if (msg.role === 'user') {
        // Only the trailing user message is sent as the `message` parameter.
        if (i === messages.length - 1) {
          lastMessage = msg.content;
        } else {
          chatHistory.push({ role: 'USER', message: msg.content });
        }
      } else if (msg.role === 'assistant') {
        chatHistory.push({ role: 'CHATBOT', message: msg.content });
      }
    }

    return {
      preamble: preamble.join('\n\n') || undefined,
      chat_history: chatHistory.length > 0 ? chatHistory : undefined,
      message: lastMessage
    };
  }

  /**
   * Convert a Cohere chat response into the OpenAI-like result shape.
   *
   * @param {object} cohereResponse - Parsed JSON body of the chat response.
   * @returns {{content: string, model: string, usage: object, raw: object}}
   *   Token counts are read from `meta.billed_units` and default to 0.
   */
  convertResponse(cohereResponse) {
    const billed = cohereResponse.meta?.billed_units;
    const promptTokens = billed?.input_tokens || 0;
    const completionTokens = billed?.output_tokens || 0;

    return {
      content: cohereResponse.text || '',
      model: cohereResponse.model || this.defaultModel,
      usage: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens
      },
      raw: cohereResponse
    };
  }

  /**
   * Build the JSON body shared by chat() and chatStream().
   *
   * @param {string} [model] - Falls back to `defaultModel` when falsy.
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - Reads temperature, max_tokens, top_p, top_k.
   * @param {boolean} [stream] - Value for the `stream` flag.
   * @returns {object} Request body for the /chat endpoint.
   */
  buildChatBody(model, messages, options = {}, stream = false) {
    const { preamble, chat_history, message } = this.convertMessages(messages);

    return {
      model: model || this.defaultModel,
      message,
      chat_history,
      preamble,
      temperature: options.temperature,
      max_tokens: options.max_tokens,
      p: options.top_p,
      k: options.top_k,
      stream
    };
  }

  /**
   * POST a JSON body to `baseUrl + path` and return the successful Response.
   *
   * @param {string} path - Path appended to `baseUrl` (e.g. '/chat').
   * @param {object} body - Serialized with JSON.stringify.
   * @param {function(Response): string} fallbackMessage - Builds the error
   *   text used when the error body has no `message` field.
   * @returns {Promise<Response>} The response, guaranteed to have `ok` true.
   * @throws {ProviderError} When the response status is not ok.
   */
  async postJson(path, body, fallbackMessage) {
    const response = await fetch(`${this.baseUrl}${path}`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${this.apiKey}`,
        // Spread last so configured headers can override the defaults.
        ...this.headers
      },
      body: JSON.stringify(body),
      signal: this.timeout ? AbortSignal.timeout(this.timeout) : undefined
    });

    if (!response.ok) {
      // The error body may not be JSON; fall back to an empty object.
      const error = await response.json().catch(() => ({}));
      throw new ProviderError(error.message || fallbackMessage(response));
    }

    return response;
  }

  /**
   * Send a non-streaming chat request.
   *
   * @param {string} [model]
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - Reads temperature, max_tokens, top_p, top_k.
   * @returns {Promise<object>} Result of convertResponse().
   * @throws {ProviderError} When not connected or the API reports an error.
   */
  async chat(model, messages, options = {}) {
    if (!this.connected) {
      throw new ProviderError('Cohere provider not connected');
    }

    const response = await this.postJson(
      '/chat',
      this.buildChatBody(model, messages, options, false),
      (res) => `Cohere API error: ${res.status} ${res.statusText}`
    );

    return this.convertResponse(await response.json());
  }

  /**
   * Send a streaming chat request and yield chunks as they arrive.
   *
   * The response body is read as newline-delimited JSON. Each
   * `text-generation` event yields `{ type: 'content', content, done: false }`;
   * a `stream-end` event (or the end of the body) yields `{ done: true }`.
   *
   * @param {string} [model]
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - Reads temperature, max_tokens, top_p, top_k.
   * @throws {ProviderError} When not connected or the API reports an error.
   */
  async *chatStream(model, messages, options = {}) {
    if (!this.connected) {
      throw new ProviderError('Cohere provider not connected');
    }

    const response = await this.postJson(
      '/chat',
      this.buildChatBody(model, messages, options, true),
      (res) => `Cohere API error: ${res.status} ${res.statusText}`
    );

    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    let buffer = '';
    let finished = false;

    try {
      while (!finished) {
        const { done, value } = await reader.read();
        finished = done;

        // On the final pass flush the decoder and treat whatever is left in
        // the buffer as a complete line, so a last event without a trailing
        // newline is not lost.
        buffer += done ? decoder.decode() : decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        buffer = done ? '' : lines.pop();

        for (const line of lines) {
          if (!line.trim()) continue;

          let event;
          try {
            event = JSON.parse(line);
          } catch {
            // Best effort: lines that are not valid JSON are skipped.
            continue;
          }

          if (event?.event_type === 'text-generation') {
            yield { type: 'content', content: event.text, done: false };
          } else if (event?.event_type === 'stream-end') {
            yield { done: true };
            return;
          }
        }
      }

      yield { done: true };
    } finally {
      // Runs on normal completion, on errors, and when the consumer stops
      // iterating early; cancelling releases the underlying response body.
      await reader.cancel().catch(() => {});
    }
  }

  /**
   * Request embeddings for one or more texts.
   *
   * @param {string|string[]} texts - A single value is wrapped in an array.
   * @param {string} [model]
   * @returns {Promise<*>} The `embeddings` field of the response body.
   * @throws {ProviderError} When not connected or the API reports an error.
   */
  async embeddings(texts, model = 'embed-english-v3.0') {
    if (!this.connected) {
      throw new ProviderError('Cohere provider not connected');
    }

    const response = await this.postJson(
      '/embed',
      {
        model,
        texts: Array.isArray(texts) ? texts : [texts],
        input_type: 'search_document'
      },
      (res) => `Embedding API error: ${res.status}`
    );

    const data = await response.json();
    return data.embeddings;
  }

  /**
   * Return the configured model list. No network request is made.
   *
   * @returns {Promise<string[]>}
   */
  async fetchModels() {
    return this.models;
  }

  /**
   * Return the configured model list (synchronous).
   *
   * @returns {string[]}
   */
  getModels() {
    return this.models;
  }

  /**
   * Return a snapshot of the adapter state. The API key itself is not
   * included, only whether one is set.
   *
   * @returns {object}
   */
  getStatus() {
    return {
      id: this.id,
      name: this.name,
      nameCn: this.nameCn,
      baseUrl: this.baseUrl,
      connected: this.connected,
      modelCount: this.models.length,
      defaultModel: this.defaultModel,
      hasApiKey: !!this.apiKey,
      transport: 'cohere_chat'
    };
  }
}
307
+
308
/**
 * Create a CohereAdapter from the stock Cohere settings.
 *
 * @param {string|null} [apiKey] - API key stored on the adapter.
 * @param {object} [overrides] - Spread over the stock settings, so any key
 *   given here (including `apiKey`) replaces the stock value.
 * @returns {CohereAdapter}
 */
export function createCohereProvider(apiKey = null, overrides = {}) {
  const settings = {
    id: 'cohere',
    name: 'Cohere',
    nameCn: 'Cohere',
    baseUrl: 'https://api.cohere.ai/v1',
    apiKey,
    ...overrides
  };

  return new CohereAdapter(settings);
}
318
+
319
+ export default CohereAdapter;
@@ -0,0 +1,282 @@
1
+ import { ProviderError } from './provider-error-adapter.js';
2
/**
 * Google Gemini API adapter.
 *
 * Google Gemini uses its own API format:
 * - Endpoint: /v1beta/models/{model}:generateContent
 * - Header: x-goog-api-key or Authorization: Bearer
 * - Request shape: { contents: [...], generationConfig: {...} }
 * - Response shape: { candidates: [{ content: { parts: [...] } }] }
 *
 * Reference: https://ai.google.dev/api/rest
 */
export class GeminiAdapter {
  /**
   * @param {object} [config] - Adapter settings; every field is optional and
   *   falls back to the default shown in the constructor body.
   * @param {number} [config.timeout] - Request timeout in milliseconds.
   *   An explicit 0 disables the timeout; omitted/null uses 60000.
   * @param {object} [config.headers] - Extra headers merged into every
   *   request after the built-in ones, so they can override them.
   */
  constructor(config = {}) {
    this.id = config.id || 'gemini';
    this.name = config.name || 'Gemini';
    this.nameCn = config.nameCn || 'Google Gemini';
    this.baseUrl = config.baseUrl || 'https://generativelanguage.googleapis.com/v1beta';
    this.apiKey = config.apiKey || null;
    this.defaultModel = config.defaultModel || 'gemini-2.0-flash-exp';
    this.models = config.models || [
      'gemini-2.0-flash-exp',
      'gemini-2.0-flash-thinking-exp-1219',
      'gemini-1.5-pro',
      'gemini-1.5-flash',
      'gemini-1.5-flash-8b'
    ];
    this.connected = false;
    this.description = config.description || 'Google Gemini 系列模型';
    // `??` (not `||`) so that an explicit 0 survives; postJson() treats a
    // falsy timeout as "do not attach an abort signal".
    this.timeout = config.timeout ?? 60000;
    this.headers = config.headers || {};
  }

  /**
   * Mark the adapter as connected. No network request is made; the only
   * check is that an API key is available.
   *
   * @param {string} [apiKey] - Replaces the stored key when truthy.
   * @returns {Promise<boolean>} Always resolves to true on success.
   * @throws {ProviderError} When no API key is set.
   */
  async connect(apiKey) {
    if (apiKey) this.apiKey = apiKey;

    if (!this.apiKey) {
      throw new ProviderError('API Key required for Google Gemini');
    }

    this.connected = true;
    return true;
  }

  /**
   * Mark the adapter as disconnected.
   */
  disconnect() {
    this.connected = false;
  }

  /**
   * Convert OpenAI-style messages into Gemini request fields.
   *
   * - `system` messages are joined with blank lines into `systemInstruction`.
   * - `user` messages keep the role `user`; `assistant` becomes `model`.
   * - Messages with any other role are skipped.
   *
   * @param {Array<{role: string, content: *}>} messages
   * @returns {{systemInstruction: (object|undefined), contents: Array}}
   */
  convertMessages(messages) {
    const systemInstructions = [];
    const contents = [];

    for (const msg of messages) {
      if (msg.role === 'system') {
        systemInstructions.push(msg.content);
      } else if (msg.role === 'user') {
        contents.push({ role: 'user', parts: [{ text: msg.content }] });
      } else if (msg.role === 'assistant') {
        contents.push({ role: 'model', parts: [{ text: msg.content }] });
      }
    }

    return {
      systemInstruction: systemInstructions.length > 0
        ? { parts: [{ text: systemInstructions.join('\n\n') }] }
        : undefined,
      contents
    };
  }

  /**
   * Concatenate the `text` of every part of the first candidate.
   *
   * Shared by convertResponse() and chatStream() so both read multi-part
   * candidates the same way.
   *
   * @param {object} payload - Parsed response (or stream chunk) JSON.
   * @returns {string} Joined text, or '' when there are no parts.
   */
  extractText(payload) {
    const parts = payload?.candidates?.[0]?.content?.parts;
    if (!Array.isArray(parts)) return '';
    return parts.map((part) => part?.text ?? '').join('');
  }

  /**
   * Convert a Gemini response into the OpenAI-like result shape.
   *
   * @param {object} geminiResponse - Parsed JSON body of the response.
   * @returns {{content: string, model: string, usage: object, raw: object}}
   *   Token counts are read from `usageMetadata` and default to 0.
   */
  convertResponse(geminiResponse) {
    const usage = geminiResponse.usageMetadata;

    return {
      content: this.extractText(geminiResponse),
      model: geminiResponse.modelVersion || this.defaultModel,
      usage: {
        prompt_tokens: usage?.promptTokenCount || 0,
        completion_tokens: usage?.candidatesTokenCount || 0,
        total_tokens: usage?.totalTokenCount || 0
      },
      raw: geminiResponse
    };
  }

  /**
   * Build the JSON body shared by chat() and chatStream().
   *
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - Reads temperature, top_p, top_k, max_tokens
   *   (max_tokens falls back to 2048 when falsy).
   * @returns {object} Request body; `systemInstruction` is only present when
   *   the messages contain at least one system message.
   */
  buildRequestBody(messages, options = {}) {
    const { systemInstruction, contents } = this.convertMessages(messages);

    const body = {
      contents,
      generationConfig: {
        temperature: options.temperature,
        topP: options.top_p,
        topK: options.top_k,
        maxOutputTokens: options.max_tokens || 2048
      }
    };

    if (systemInstruction) {
      body.systemInstruction = systemInstruction;
    }

    return body;
  }

  /**
   * POST a JSON body to `url` and return the successful Response.
   *
   * @param {string} url - Full request URL.
   * @param {object} body - Serialized with JSON.stringify.
   * @returns {Promise<Response>} The response, guaranteed to have `ok` true.
   * @throws {ProviderError} When the response status is not ok; the message
   *   is `error.message` from the error body when present.
   */
  async postJson(url, body) {
    const response = await fetch(url, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'x-goog-api-key': this.apiKey,
        // Spread last so configured headers can override the defaults.
        ...this.headers
      },
      body: JSON.stringify(body),
      signal: this.timeout ? AbortSignal.timeout(this.timeout) : undefined
    });

    if (!response.ok) {
      // The error body may not be JSON; fall back to an empty object.
      const error = await response.json().catch(() => ({}));
      throw new ProviderError(
        error.error?.message ||
        `Gemini API error: ${response.status} ${response.statusText}`
      );
    }

    return response;
  }

  /**
   * Send a non-streaming generateContent request.
   *
   * @param {string} [model] - Falls back to `defaultModel` when falsy.
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - See buildRequestBody().
   * @returns {Promise<object>} Result of convertResponse().
   * @throws {ProviderError} When not connected or the API reports an error.
   */
  async chat(model, messages, options = {}) {
    if (!this.connected) {
      throw new ProviderError('Gemini provider not connected');
    }

    const modelName = model || this.defaultModel;
    const response = await this.postJson(
      `${this.baseUrl}/models/${modelName}:generateContent`,
      this.buildRequestBody(messages, options)
    );

    return this.convertResponse(await response.json());
  }

  /**
   * Send a streamGenerateContent request (SSE) and yield text as it arrives.
   *
   * Every `data: ` line is parsed as JSON; when it carries text, a
   * `{ type: 'content', content, done: false }` chunk is yielded. After the
   * body ends a final `{ done: true }` is yielded.
   *
   * @param {string} [model] - Falls back to `defaultModel` when falsy.
   * @param {Array} messages - OpenAI-style messages.
   * @param {object} [options] - See buildRequestBody().
   * @throws {ProviderError} When not connected or the API reports an error.
   */
  async *chatStream(model, messages, options = {}) {
    if (!this.connected) {
      throw new ProviderError('Gemini provider not connected');
    }

    const modelName = model || this.defaultModel;
    const response = await this.postJson(
      `${this.baseUrl}/models/${modelName}:streamGenerateContent?alt=sse`,
      this.buildRequestBody(messages, options)
    );

    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    let buffer = '';
    let finished = false;

    try {
      while (!finished) {
        const { done, value } = await reader.read();
        finished = done;

        // On the final pass flush the decoder and treat whatever is left in
        // the buffer as a complete line, so a last `data:` line without a
        // trailing newline is not lost.
        buffer += done ? decoder.decode() : decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        buffer = done ? '' : lines.pop();

        for (const line of lines) {
          if (!line.startsWith('data: ')) continue;

          let payload;
          try {
            payload = JSON.parse(line.slice(6));
          } catch {
            // Best effort: data lines that are not valid JSON are skipped.
            continue;
          }

          const text = this.extractText(payload);
          if (text) {
            yield { type: 'content', content: text, done: false };
          }
        }
      }

      yield { done: true };
    } finally {
      // Runs on normal completion, on errors, and when the consumer stops
      // iterating early; cancelling releases the underlying response body.
      await reader.cancel().catch(() => {});
    }
  }

  /**
   * Return the configured model list. No network request is made.
   *
   * @returns {Promise<string[]>}
   */
  async fetchModels() {
    return this.models;
  }

  /**
   * Return the configured model list (synchronous).
   *
   * @returns {string[]}
   */
  getModels() {
    return this.models;
  }

  /**
   * Return a snapshot of the adapter state. The API key itself is not
   * included, only whether one is set.
   *
   * @returns {object}
   */
  getStatus() {
    return {
      id: this.id,
      name: this.name,
      nameCn: this.nameCn,
      baseUrl: this.baseUrl,
      connected: this.connected,
      modelCount: this.models.length,
      defaultModel: this.defaultModel,
      hasApiKey: !!this.apiKey,
      transport: 'gemini_generate_content'
    };
  }
}
270
+
271
+ export function createGeminiProvider(apiKey = null, overrides = {}) {
272
+ return new GeminiAdapter({
273
+ id: 'gemini',
274
+ name: 'Gemini',
275
+ nameCn: 'Google Gemini',
276
+ baseUrl: 'https://generativelanguage.googleapis.com/v1beta',
277
+ apiKey,
278
+ ...overrides
279
+ });
280
+ }
281
+
282
+ export default GeminiAdapter;