rax-flow-providers 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/claude-adapter.d.ts +41 -0
- package/dist/claude-adapter.d.ts.map +1 -0
- package/dist/claude-adapter.js +236 -0
- package/dist/claude-adapter.js.map +1 -0
- package/dist/cohere-adapter.d.ts +37 -0
- package/dist/cohere-adapter.d.ts.map +1 -0
- package/dist/cohere-adapter.js +160 -0
- package/dist/cohere-adapter.js.map +1 -0
- package/dist/error-mapper.d.ts +51 -0
- package/dist/error-mapper.d.ts.map +1 -0
- package/dist/error-mapper.js +132 -0
- package/dist/error-mapper.js.map +1 -0
- package/dist/gemini-adapter.d.ts +37 -0
- package/dist/gemini-adapter.d.ts.map +1 -0
- package/dist/gemini-adapter.js +150 -0
- package/dist/gemini-adapter.js.map +1 -0
- package/dist/groq-adapter.d.ts +35 -0
- package/dist/groq-adapter.d.ts.map +1 -0
- package/dist/groq-adapter.js +152 -0
- package/dist/groq-adapter.js.map +1 -0
- package/dist/host-bridge-adapter.d.ts +20 -0
- package/dist/host-bridge-adapter.d.ts.map +1 -0
- package/dist/host-bridge-adapter.js +145 -0
- package/dist/host-bridge-adapter.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +12 -0
- package/dist/index.js.map +1 -0
- package/dist/mistral-adapter.d.ts +39 -0
- package/dist/mistral-adapter.d.ts.map +1 -0
- package/dist/mistral-adapter.js +172 -0
- package/dist/mistral-adapter.js.map +1 -0
- package/dist/openai-adapter.d.ts +30 -0
- package/dist/openai-adapter.d.ts.map +1 -0
- package/dist/openai-adapter.js +171 -0
- package/dist/openai-adapter.js.map +1 -0
- package/dist/pricing.d.ts +15 -0
- package/dist/pricing.d.ts.map +1 -0
- package/dist/pricing.js +61 -0
- package/dist/pricing.js.map +1 -0
- package/dist/rest-adapter.d.ts +32 -0
- package/dist/rest-adapter.d.ts.map +1 -0
- package/dist/rest-adapter.js +124 -0
- package/dist/rest-adapter.js.map +1 -0
- package/dist/strategy.d.ts +38 -0
- package/dist/strategy.d.ts.map +1 -0
- package/dist/strategy.js +117 -0
- package/dist/strategy.js.map +1 -0
- package/dist/utils.d.ts +3 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +22 -0
- package/dist/utils.js.map +1 -0
- package/package.json +18 -0
- package/src/claude-adapter.ts +350 -0
- package/src/cohere-adapter.ts +262 -0
- package/src/error-mapper.ts +187 -0
- package/src/gemini-adapter.ts +246 -0
- package/src/groq-adapter.ts +234 -0
- package/src/host-bridge-adapter.ts +189 -0
- package/src/index.ts +11 -0
- package/src/mistral-adapter.ts +262 -0
- package/src/openai-adapter.ts +240 -0
- package/src/pricing.ts +77 -0
- package/src/rest-adapter.ts +181 -0
- package/src/strategy.ts +166 -0
- package/src/utils.ts +18 -0
- package/tsconfig.json +18 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file claude-adapter.ts
|
|
3
|
+
* Anthropic Claude bridge adapter.
|
|
4
|
+
*
|
|
5
|
+
* API reference: https://docs.anthropic.com/en/api/messages
|
|
6
|
+
*
|
|
7
|
+
* • callModel → POST /v1/messages (plain text, streaming=false)
|
|
8
|
+
* • callStructured → same endpoint using Claude's tool_use mechanism
|
|
9
|
+
* (the structured extraction tool), which is the most
|
|
10
|
+
* reliable way to get valid JSON from Claude without
|
|
11
|
+
* prompt-hacking. Falls back to prompt-only for
|
|
12
|
+
* claude-3-haiku which may not have tool_use.
|
|
13
|
+
*
|
|
14
|
+
* Supported models: claude-opus-4-5, claude-sonnet-4-5,
|
|
15
|
+
* claude-3-7-sonnet-latest, claude-3-5-sonnet-latest,
|
|
16
|
+
* claude-3-5-haiku-latest, claude-3-haiku-20240307
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import { IModelProvider, ModelResponse, ProviderCallOptions } from "@rax-flow/core";
|
|
20
|
+
import { asString, parseJsonObjectFromText } from "./utils.js";
|
|
21
|
+
import {
|
|
22
|
+
RaxProviderError,
|
|
23
|
+
mapHttpError,
|
|
24
|
+
mapNetworkError,
|
|
25
|
+
mapParseError,
|
|
26
|
+
} from "./error-mapper.js";
|
|
27
|
+
import { calculateCost } from "./pricing.js";
|
|
28
|
+
|
|
29
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
30
|
+
// Wire shapes
|
|
31
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
32
|
+
|
|
33
|
+
/**
 * One conversation turn placed in the `messages` array of a request.
 * The adapter in this file only ever sends a single `user` turn.
 */
type ClaudeMessage = {
  /** Author of the turn. */
  role: "user" | "assistant";
  /** Plain-text body of the turn. */
  content: string;
};
|
|
37
|
+
|
|
38
|
+
/** Tool definition placed in the `tools` array of a request. */
type ClaudeTool = {
  /** Identifier that the request's `tool_choice.name` refers to. */
  name: string;
  /** Free-text description of what the tool is for. */
  description: string;
  /** Schema object for the tool input; `callStructured` passes the caller's schema here. */
  input_schema: object;
};
|
|
43
|
+
|
|
44
|
+
/** JSON body this adapter POSTs to `/v1/messages`. */
type ClaudeRequest = {
  /** Model identifier to run. */
  model: string;
  /** Upper bound on generated tokens; always set by this adapter. */
  max_tokens: number;
  /** Sampling temperature; omitted when not provided. */
  temperature?: number;
  /** Conversation turns. */
  messages: ClaudeMessage[];
  /** Tool definitions; only set on the tool-use structured path. */
  tools?: ClaudeTool[];
  /** Tool selection policy; `name` is supplied together with `type: "tool"` by this adapter. */
  tool_choice?: { type: "auto" | "any" | "tool"; name?: string };
};
|
|
52
|
+
|
|
53
|
+
/** Text segment of a reply. */
type ClaudeTextBlock = { type: "text"; text: string };

/** Tool invocation segment of a reply; `input` carries the tool arguments. */
type ClaudeToolUseBlock = { type: "tool_use"; id: string; name: string; input: unknown };

/** One element of a reply's `content` array, discriminated by `type`. */
type ClaudeContentBlock = ClaudeTextBlock | ClaudeToolUseBlock;
|
|
56
|
+
|
|
57
|
+
/**
 * Reply body as read by this adapter. Every field is optional because the
 * same shape is used for both successful replies and error bodies.
 */
type ClaudeResponse = {
  id?: string;
  /** Reply segments (text and/or tool_use). */
  content?: ClaudeContentBlock[];
  stop_reason?: string;
  /** Token counts reported for the call. */
  usage?: { input_tokens?: number; output_tokens?: number };

  // ── Error shapes ──
  /** The adapter treats the value `"error"` (together with `error`) as a failure. */
  type?: string;
  /** Error detail accompanying `type: "error"`. */
  error?: { type?: string; message?: string };
};
|
|
66
|
+
|
|
67
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
68
|
+
// Adapter
|
|
69
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
70
|
+
|
|
71
|
+
/** Provider tag stamped on responses and passed to the error mappers. */
const VENDOR = "claude" as const;

/** Model used when neither the call options nor the adapter options name one. */
const DEFAULT_MODEL = "claude-3-5-sonnet-latest" as const;

/** API origin used when no `baseUrl` is configured. */
const DEFAULT_BASE_URL = "https://api.anthropic.com" as const;

/** Value sent in the `anthropic-version` request header. */
const ANTHROPIC_VERSION = "2023-06-01" as const;
|
|
75
|
+
|
|
76
|
+
/**
 * Model identifiers for which `callStructured` takes the tool_use route.
 * A model matches when it equals an entry or starts with `<entry>-`.
 */
const TOOL_USE_MODELS = new Set<string>([
  // Claude 4.5 family
  "claude-opus-4-5",
  "claude-sonnet-4-5",
  // Claude 3.7 / 3.5 aliases
  "claude-3-7-sonnet-latest",
  "claude-3-5-sonnet-latest",
  "claude-3-5-haiku-latest",
  // Claude 3 dated snapshots
  "claude-3-opus-20240229",
  "claude-3-sonnet-20240229",
]);
|
|
86
|
+
|
|
87
|
+
/** Construction options for {@link ClaudeAdapter}. */
export interface ClaudeAdapterOptions {
  /** Key sent in the `x-api-key` header of every request. */
  apiKey: string;
  /** API origin; the adapter falls back to `https://api.anthropic.com` when omitted. */
  baseUrl?: string;
  /** Model used when a call does not specify one; falls back to `claude-3-5-sonnet-latest`. */
  defaultModel?: string;
  /** Per-request abort timeout in milliseconds; falls back to 30000. */
  timeoutMs?: number;
}
|
|
93
|
+
|
|
94
|
+
/**
 * Provider adapter for the Anthropic Messages endpoint
 * (`POST {baseUrl}/v1/messages`).
 *
 * - `callModel` sends one user turn and returns the reply's text blocks
 *   joined with newlines.
 * - `callStructured` forces the `structured_output` tool on models accepted
 *   by `supportsToolUse`; for any other model it appends the schema to the
 *   prompt and parses a JSON object out of the text reply.
 */
export class ClaudeAdapter implements IModelProvider {
  /** Model-name prefixes that take the tool_use route regardless of suffix. */
  private static readonly TOOL_USE_FAMILY_PREFIXES: readonly string[] = [
    "claude-3-5",
    "claude-3-7",
    "claude-opus",
    "claude-sonnet",
  ];

  private readonly apiKey: string;
  private readonly baseUrl: string;
  private readonly defaultModel: string;
  private readonly timeoutMs: number;

  constructor(options: ClaudeAdapterOptions);
  /** @deprecated Pass an options object instead. */
  constructor(apiKey: string, baseUrl?: string);
  constructor(
    optionsOrApiKey: ClaudeAdapterOptions | string,
    legacyBaseUrl?: string
  ) {
    if (typeof optionsOrApiKey === "string") {
      // Legacy positional form: model and timeout always take their defaults.
      this.apiKey = optionsOrApiKey;
      this.baseUrl = legacyBaseUrl ?? DEFAULT_BASE_URL;
      this.defaultModel = DEFAULT_MODEL;
      this.timeoutMs = 30_000;
    } else {
      this.apiKey = optionsOrApiKey.apiKey;
      this.baseUrl = optionsOrApiKey.baseUrl ?? DEFAULT_BASE_URL;
      this.defaultModel = optionsOrApiKey.defaultModel ?? DEFAULT_MODEL;
      this.timeoutMs = optionsOrApiKey.timeoutMs ?? 30_000;
    }
  }

  // ── private helpers ────────────────────────────────────────────────────────

  /**
   * Decides whether `callStructured` should use the tool_use route.
   *
   * @param model Model identifier for the call.
   * @returns `true` when the model starts with a known family prefix, equals
   *   a `TOOL_USE_MODELS` entry, or starts with `<entry>-`.
   */
  private supportsToolUse(model: string): boolean {
    // The family prefixes do not depend on the set entries, so check them once.
    if (ClaudeAdapter.TOOL_USE_FAMILY_PREFIXES.some((prefix) => model.startsWith(prefix))) {
      return true;
    }
    for (const known of TOOL_USE_MODELS) {
      if (model === known || model.startsWith(`${known}-`)) {
        return true;
      }
    }
    return false;
  }

  /**
   * POSTs `body` to `/v1/messages` and returns the decoded reply.
   *
   * The abort timer stays armed until the response body has been read, so a
   * stalled body stream is bounded by `timeoutMs` just like the initial
   * request.
   *
   * @param body Request payload, serialised with `JSON.stringify`.
   * @param context Calling operation, forwarded to `mapHttpError`.
   * @returns The reply body cast to `ClaudeResponse` (not validated).
   * @throws The result of `mapNetworkError` when `fetch` rejects or the
   *   request is aborted (including while the body is being read); the result
   *   of `mapHttpError` on a non-2xx status; `RaxProviderError` when a 2xx
   *   body carries `type: "error"`.
   */
  private async post(
    body: ClaudeRequest,
    context: "callModel" | "callStructured"
  ): Promise<ClaudeResponse> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), this.timeoutMs);

    let res: Response;
    let raw; // takes the untyped value produced by res.json()
    try {
      res = await fetch(`${this.baseUrl}/v1/messages`, {
        method: "POST",
        headers: {
          "x-api-key": this.apiKey,
          "anthropic-version": ANTHROPIC_VERSION,
          "content-type": "application/json",
        },
        body: JSON.stringify(body),
        signal: controller.signal,
      });

      try {
        raw = await res.json();
      } catch (err) {
        // A timeout while reading the body must surface as a network error.
        if (controller.signal.aborted) {
          throw err;
        }
        // Otherwise keep the best-effort behaviour: an unreadable or non-JSON
        // body is treated as an empty object.
        raw = {};
      }
    } catch (err) {
      throw mapNetworkError(VENDOR, err);
    } finally {
      clearTimeout(timer);
    }

    if (!res.ok) {
      throw mapHttpError(VENDOR, res.status, raw, context);
    }

    const payload = raw as ClaudeResponse;

    // A 2xx response may still carry an error envelope (type: "error").
    if (payload.type === "error" && payload.error) {
      const errType = payload.error.type ?? "";
      const errMsg = payload.error.message ?? "unknown claude error";
      const code =
        errType === "overloaded_error" ? "server_error" :
        errType === "rate_limit_error" ? "auth_quota_exceeded" :
        errType === "authentication_error" ? "auth_invalid" :
        (errType.includes("content") || errType.includes("safety")) ? "content_filtered" :
        "unknown";
      throw new RaxProviderError(VENDOR, code, errMsg, { raw: payload });
    }

    return payload;
  }

  /** Joins the `text` of every text block with newlines; `""` when there is no content. */
  private extractText(payload: ClaudeResponse): string {
    return (
      payload.content
        ?.filter((c): c is { type: "text"; text: string } => c.type === "text")
        .map((c) => c.text)
        .join("\n") ?? ""
    );
  }

  /** Returns the `input` of the first tool_use block, or `null` when none is present. */
  private extractToolInput(payload: ClaudeResponse): unknown | null {
    const toolBlock = payload.content?.find(
      (c): c is { type: "tool_use"; id: string; name: string; input: unknown } =>
        c.type === "tool_use"
    );
    return toolBlock?.input ?? null;
  }

  /**
   * Maps the reply's token counts to the provider-neutral usage shape.
   *
   * @returns `undefined` when the reply has no `usage`; missing counts are
   *   treated as 0.
   */
  private toUsage(
    payload: ClaudeResponse
  ): { promptTokens: number; completionTokens: number; totalTokens: number } | undefined {
    if (!payload.usage) {
      return undefined;
    }
    const promptTokens = payload.usage.input_tokens ?? 0;
    const completionTokens = payload.usage.output_tokens ?? 0;
    return { promptTokens, completionTokens, totalTokens: promptTokens + completionTokens };
  }

  /**
   * Assembles the `ModelResponse` envelope shared by every successful call.
   *
   * @param model Model the request was sent to.
   * @param started `Date.now()` captured when the call began; used for `latencyMs`.
   * @param payload Decoded reply, attached as `raw` and used for usage and cost.
   * @param output Value to expose as `output`.
   */
  private buildResponse<T>(
    model: string,
    started: number,
    payload: ClaudeResponse,
    output: T
  ): ModelResponse<T> {
    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output,
      raw: payload,
    };
  }

  // ── IModelProvider ─────────────────────────────────────────────────────────

  /**
   * Sends `prompt` as a single user turn and returns the reply text.
   *
   * @param prompt User message content.
   * @param options Optional overrides; defaults are the adapter's default
   *   model, `maxTokens` 1200 and `temperature` 0.2.
   * @returns Response whose `output` is the joined text blocks.
   * @throws Whatever `post` throws.
   */
  async callModel(
    prompt: string,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<string>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    const payload = await this.post(
      {
        model,
        max_tokens: options?.maxTokens ?? 1200,
        temperature: options?.temperature ?? 0.2,
        messages: [{ role: "user", content: prompt }],
      },
      "callModel"
    );

    return this.buildResponse(model, started, payload, this.extractText(payload));
  }

  /**
   * Requests a JSON object matching `schema`.
   *
   * @param prompt User message content.
   * @param schema Schema object; sent as the tool's `input_schema` on the
   *   tool-use route, or appended to the prompt on the prompt-only route.
   * @param options Optional overrides; defaults are the adapter's default
   *   model, `maxTokens` 1400 and `temperature` 0 on both routes.
   * @returns Response whose `output` is the tool input or the parsed object
   *   (cast to `T`, not validated against `schema`).
   * @throws Whatever `post` throws; the result of `mapParseError` when no
   *   JSON object can be obtained from the reply.
   */
  async callStructured<T>(
    prompt: string,
    schema: object,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<T>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    if (this.supportsToolUse(model)) {
      // ── Tool-use path ───────────────────────────────────────────────────────
      const extractionTool: ClaudeTool = {
        name: "structured_output",
        description: "Extract the required structured data and return it as JSON.",
        input_schema: schema,
      };

      const payload = await this.post(
        {
          model,
          max_tokens: options?.maxTokens ?? 1400,
          temperature: options?.temperature ?? 0,
          messages: [{ role: "user", content: prompt }],
          tools: [extractionTool],
          tool_choice: { type: "tool", name: "structured_output" },
        },
        "callStructured"
      );

      const toolInput = this.extractToolInput(payload);
      if (toolInput && typeof toolInput === "object") {
        return this.buildResponse(model, started, payload, toolInput as T);
      }

      // No usable tool_use block: try to parse JSON out of the text blocks.
      const textFallback = this.extractText(payload);
      const parsed = parseJsonObjectFromText<T>(textFallback);
      if (parsed) {
        return this.buildResponse(model, started, payload, parsed);
      }

      throw mapParseError(VENDOR, "callStructured", textFallback);
    }

    // ── Prompt-only path (models not accepted by supportsToolUse) ────────────
    const strictPrompt = [
      prompt,
      "",
      "Return ONLY valid JSON matching this schema (no other text):",
      JSON.stringify(schema, null, 2),
    ].join("\n");

    const payload = await this.post(
      {
        model,
        max_tokens: options?.maxTokens ?? 1400,
        // Honour a caller-supplied temperature, as the tool-use path does.
        temperature: options?.temperature ?? 0,
        messages: [{ role: "user", content: strictPrompt }],
      },
      "callStructured"
    );

    const text = this.extractText(payload);
    const parsed = parseJsonObjectFromText<T>(asString(text));

    if (!parsed) {
      throw mapParseError(VENDOR, "callStructured", text);
    }

    return this.buildResponse(model, started, payload, parsed);
  }

  /**
   * Configuration check only: resolves `true` when both the API key and the
   * base URL are non-empty. No request is sent.
   */
  async healthCheck(): Promise<boolean> {
    return Boolean(this.apiKey && this.baseUrl);
  }
}
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file cohere-adapter.ts
|
|
3
|
+
* Cohere bridge adapter (Command R / Command R+).
|
|
4
|
+
*
|
|
5
|
+
* API reference: https://docs.cohere.com/reference/chat
|
|
6
|
+
*
|
|
7
|
+
* • callModel → POST /v2/chat (v2 chat API — streaming=false)
|
|
8
|
+
* • callStructured → same endpoint with response_format="json_object"
|
|
9
|
+
* and the JSON schema repeated in a leading system message
|
|
10
|
+
*
|
|
11
|
+
* Supported models: command-r-plus-08-2024, command-r-08-2024,
|
|
12
|
+
* command-r7b-12-2024, command-a-03-2025, etc.
|
|
13
|
+
*
|
|
14
|
+
* Notes:
|
|
15
|
+
* - Cohere v2 requests take a "messages" array (even for single-turn); the reply arrives in a single "message" object.
|
|
16
|
+
* - Structured JSON is requested via response_format plus a system-role message that carries the schema.
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import { IModelProvider, ModelResponse, ProviderCallOptions } from "@rax-flow/core";
|
|
20
|
+
import { parseJsonObjectFromText } from "./utils.js";
|
|
21
|
+
import {
|
|
22
|
+
RaxProviderError,
|
|
23
|
+
mapHttpError,
|
|
24
|
+
mapNetworkError,
|
|
25
|
+
mapParseError,
|
|
26
|
+
} from "./error-mapper.js";
|
|
27
|
+
import { calculateCost } from "./pricing.js";
|
|
28
|
+
|
|
29
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
30
|
+
// Cohere v2 wire shapes
|
|
31
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
32
|
+
|
|
33
|
+
/** One entry of the `messages` array in a chat request. */
type CohereMessage = {
  /** Author of the turn; this adapter sends `system` and `user` turns. */
  role: "system" | "user" | "assistant";
  /** Plain-text body of the turn. */
  content: string;
};
|
|
37
|
+
|
|
38
|
+
/** JSON body this adapter POSTs to `/v2/chat`. */
type CohereChatRequestV2 = {
  /** Model identifier to run. */
  model: string;
  /** Conversation turns. */
  messages: CohereMessage[];
  /** Sampling temperature; omitted when not provided. */
  temperature?: number;
  /** Upper bound on generated tokens. */
  max_tokens?: number;
  /** Output format; `callStructured` sends `json_object` together with the schema. */
  response_format?: {
    type: "text" | "json_object";
    json_schema?: object;
  };
};
|
|
48
|
+
|
|
49
|
+
/** One element of the reply's `message.content` array. */
type CohereContent = {
  /** Block kind; only `"text"` blocks are read by this adapter. */
  type: string;
  /** Text payload; a missing value is treated as an empty string. */
  text?: string;
};
|
|
53
|
+
|
|
54
|
+
/**
 * Reply body as read by this adapter. Every field is optional because the
 * same shape is used for both successful replies and error bodies.
 */
type CohereChatResponseV2 = {
  id?: string;
  /** Assistant reply; its `content` blocks hold the generated text. */
  message?: {
    role?: string;
    content?: CohereContent[];
  };
  /** Why generation stopped; the adapter special-cases `"MAX_TOKENS"`. */
  finish_reason?: string;
  /** Token accounting; only `billed_units` is read. */
  usage?: {
    billed_units?: { input_tokens?: number; output_tokens?: number };
  };

  // ── Error body ──
  /**
   * Original author's note: top-level error message field in v2 errors.
   * Not read anywhere in this adapter — TODO confirm against the API.
   */
  message_str?: string;
  /** Error text; not read anywhere in this adapter. */
  error?: string;
};
|
|
68
|
+
|
|
69
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
70
|
+
// Adapter
|
|
71
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
72
|
+
|
|
73
|
+
/** Provider tag stamped on responses and passed to the error mappers. */
const VENDOR = "cohere" as const;

/** Model used when neither the call options nor the adapter options name one. */
const DEFAULT_MODEL = "command-r-plus-08-2024" as const;

/** API origin used when no `baseUrl` is configured. */
const BASE_URL = "https://api.cohere.com" as const;
|
|
76
|
+
|
|
77
|
+
/** Construction options for {@link CohereAdapter}. */
export interface CohereAdapterOptions {
  /** Key sent as the `Authorization: Bearer …` header of every request. */
  apiKey: string;
  /** API origin; the adapter falls back to `https://api.cohere.com` when omitted. */
  baseUrl?: string;
  /** Model used when a call does not specify one; falls back to `command-r-plus-08-2024`. */
  defaultModel?: string;
  /** Per-request abort timeout in milliseconds; falls back to 30000. */
  timeoutMs?: number;
}
|
|
83
|
+
|
|
84
|
+
/**
 * Provider adapter for the Cohere chat endpoint (`POST {baseUrl}/v2/chat`).
 *
 * - `callModel` sends one user turn and returns the reply's text blocks
 *   joined with newlines.
 * - `callStructured` sends `response_format: json_object` with the schema,
 *   repeats the schema in a system message, and parses a JSON object out of
 *   the text reply.
 */
export class CohereAdapter implements IModelProvider {
  private readonly apiKey: string;
  private readonly baseUrl: string;
  private readonly defaultModel: string;
  private readonly timeoutMs: number;

  constructor(options: CohereAdapterOptions) {
    this.apiKey = options.apiKey;
    this.baseUrl = options.baseUrl ?? BASE_URL;
    this.defaultModel = options.defaultModel ?? DEFAULT_MODEL;
    this.timeoutMs = options.timeoutMs ?? 30_000;
  }

  // ── private helpers ────────────────────────────────────────────────────────

  /**
   * POSTs `body` to `/v2/chat` and returns the decoded reply.
   *
   * The abort timer stays armed until the response body has been read, so a
   * stalled body stream is bounded by `timeoutMs` just like the initial
   * request.
   *
   * @param body Request payload, serialised with `JSON.stringify`.
   * @param context Calling operation, forwarded to `mapHttpError`.
   * @returns The reply body cast to `CohereChatResponseV2` (not validated).
   * @throws The result of `mapNetworkError` when `fetch` rejects or the
   *   request is aborted (including while the body is being read); the result
   *   of `mapHttpError` on a non-2xx status; `RaxProviderError` with code
   *   `content_filtered` when the reply stopped on `MAX_TOKENS` with no
   *   content.
   */
  private async post(
    body: CohereChatRequestV2,
    context: "callModel" | "callStructured"
  ): Promise<CohereChatResponseV2> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), this.timeoutMs);

    let res: Response;
    let raw; // takes the untyped value produced by res.json()
    try {
      res = await fetch(`${this.baseUrl}/v2/chat`, {
        method: "POST",
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
          "Content-Type": "application/json",
          Accept: "application/json",
        },
        body: JSON.stringify(body),
        signal: controller.signal,
      });

      try {
        raw = await res.json();
      } catch (err) {
        // A timeout while reading the body must surface as a network error.
        if (controller.signal.aborted) {
          throw err;
        }
        // Otherwise keep the best-effort behaviour: an unreadable or non-JSON
        // body is treated as an empty object.
        raw = {};
      }
    } catch (err) {
      throw mapNetworkError(VENDOR, err);
    } finally {
      clearTimeout(timer);
    }

    if (!res.ok) {
      throw mapHttpError(VENDOR, res.status, raw, context);
    }

    const payload = raw as CohereChatResponseV2;

    // Heuristic: a reply that stopped on MAX_TOKENS yet carries no content
    // blocks is reported as content_filtered.
    // NOTE(review): MAX_TOKENS is itself a truncation signal; confirm that an
    // empty truncated reply really indicates a safety block.
    if (
      payload.finish_reason === "MAX_TOKENS" &&
      !payload.message?.content?.length
    ) {
      throw new RaxProviderError(
        VENDOR,
        "content_filtered",
        "Cohere response truncated without content — possible safety block",
        { raw: payload }
      );
    }

    return payload;
  }

  /** Joins the `text` of every `"text"` block with newlines; `""` when there is no content. */
  private extractText(payload: CohereChatResponseV2): string {
    return (
      payload.message?.content
        ?.filter((c) => c.type === "text")
        .map((c) => c.text ?? "")
        .join("\n") ?? ""
    );
  }

  /**
   * Maps the reply's billed token counts to the provider-neutral usage shape.
   *
   * @returns `undefined` when the reply has no `usage.billed_units`; missing
   *   counts are treated as 0.
   */
  private toUsage(
    payload: CohereChatResponseV2
  ): { promptTokens: number; completionTokens: number; totalTokens: number } | undefined {
    const billed = payload.usage?.billed_units;
    if (!billed) {
      return undefined;
    }
    const promptTokens = billed.input_tokens ?? 0;
    const completionTokens = billed.output_tokens ?? 0;
    return { promptTokens, completionTokens, totalTokens: promptTokens + completionTokens };
  }

  /**
   * Assembles the `ModelResponse` envelope shared by every successful call.
   *
   * @param model Model the request was sent to.
   * @param started `Date.now()` captured when the call began; used for `latencyMs`.
   * @param payload Decoded reply, attached as `raw` and used for usage and cost.
   * @param output Value to expose as `output`.
   */
  private buildResponse<T>(
    model: string,
    started: number,
    payload: CohereChatResponseV2,
    output: T
  ): ModelResponse<T> {
    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output,
      raw: payload,
    };
  }

  // ── IModelProvider ─────────────────────────────────────────────────────────

  /**
   * Sends `prompt` as a single user turn and returns the reply text.
   *
   * @param prompt User message content.
   * @param options Optional overrides; defaults are the adapter's default
   *   model, `temperature` 0.2 and `maxTokens` 1200.
   * @returns Response whose `output` is the joined text blocks.
   * @throws Whatever `post` throws.
   */
  async callModel(
    prompt: string,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<string>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    const payload = await this.post(
      {
        model,
        temperature: options?.temperature ?? 0.2,
        max_tokens: options?.maxTokens ?? 1200,
        messages: [{ role: "user", content: prompt }],
      },
      "callModel"
    );

    return this.buildResponse(model, started, payload, this.extractText(payload));
  }

  /**
   * Requests a JSON object matching `schema`.
   *
   * @param prompt User message content.
   * @param schema Schema object; sent as `response_format.json_schema` and
   *   also embedded, pretty-printed, in the system message.
   * @param options Optional overrides; defaults are the adapter's default
   *   model, `temperature` 0 and `maxTokens` 1400.
   * @returns Response whose `output` is the parsed object (cast to `T`, not
   *   validated against `schema`).
   * @throws Whatever `post` throws; the result of `mapParseError` when no
   *   JSON object can be parsed from the reply text.
   */
  async callStructured<T>(
    prompt: string,
    schema: object,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<T>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    // json_object mode, with the schema additionally spelled out for the model.
    const systemContent = [
      "You are a JSON generation assistant.",
      "ALWAYS respond with ONLY a valid JSON object.",
      "The JSON must strictly follow this schema:",
      JSON.stringify(schema, null, 2),
    ].join("\n");

    const payload = await this.post(
      {
        model,
        temperature: options?.temperature ?? 0,
        max_tokens: options?.maxTokens ?? 1400,
        response_format: {
          type: "json_object",
          json_schema: schema,
        },
        messages: [
          { role: "system", content: systemContent },
          { role: "user", content: prompt },
        ],
      },
      "callStructured"
    );

    const text = this.extractText(payload);
    const parsed = parseJsonObjectFromText<T>(text);

    if (!parsed) {
      throw mapParseError(VENDOR, "callStructured", text);
    }

    return this.buildResponse(model, started, payload, parsed);
  }

  /**
   * Configuration check only: resolves `true` when both the API key and the
   * base URL are non-empty. No request is sent.
   */
  async healthCheck(): Promise<boolean> {
    return Boolean(this.apiKey && this.baseUrl);
  }
}
|