rax-flow-providers 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/claude-adapter.d.ts +41 -0
- package/dist/claude-adapter.d.ts.map +1 -0
- package/dist/claude-adapter.js +236 -0
- package/dist/claude-adapter.js.map +1 -0
- package/dist/cohere-adapter.d.ts +37 -0
- package/dist/cohere-adapter.d.ts.map +1 -0
- package/dist/cohere-adapter.js +160 -0
- package/dist/cohere-adapter.js.map +1 -0
- package/dist/error-mapper.d.ts +51 -0
- package/dist/error-mapper.d.ts.map +1 -0
- package/dist/error-mapper.js +132 -0
- package/dist/error-mapper.js.map +1 -0
- package/dist/gemini-adapter.d.ts +37 -0
- package/dist/gemini-adapter.d.ts.map +1 -0
- package/dist/gemini-adapter.js +150 -0
- package/dist/gemini-adapter.js.map +1 -0
- package/dist/groq-adapter.d.ts +35 -0
- package/dist/groq-adapter.d.ts.map +1 -0
- package/dist/groq-adapter.js +152 -0
- package/dist/groq-adapter.js.map +1 -0
- package/dist/host-bridge-adapter.d.ts +20 -0
- package/dist/host-bridge-adapter.d.ts.map +1 -0
- package/dist/host-bridge-adapter.js +145 -0
- package/dist/host-bridge-adapter.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +12 -0
- package/dist/index.js.map +1 -0
- package/dist/mistral-adapter.d.ts +39 -0
- package/dist/mistral-adapter.d.ts.map +1 -0
- package/dist/mistral-adapter.js +172 -0
- package/dist/mistral-adapter.js.map +1 -0
- package/dist/openai-adapter.d.ts +30 -0
- package/dist/openai-adapter.d.ts.map +1 -0
- package/dist/openai-adapter.js +171 -0
- package/dist/openai-adapter.js.map +1 -0
- package/dist/pricing.d.ts +15 -0
- package/dist/pricing.d.ts.map +1 -0
- package/dist/pricing.js +61 -0
- package/dist/pricing.js.map +1 -0
- package/dist/rest-adapter.d.ts +32 -0
- package/dist/rest-adapter.d.ts.map +1 -0
- package/dist/rest-adapter.js +124 -0
- package/dist/rest-adapter.js.map +1 -0
- package/dist/strategy.d.ts +38 -0
- package/dist/strategy.d.ts.map +1 -0
- package/dist/strategy.js +117 -0
- package/dist/strategy.js.map +1 -0
- package/dist/utils.d.ts +3 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +22 -0
- package/dist/utils.js.map +1 -0
- package/package.json +18 -0
- package/src/claude-adapter.ts +350 -0
- package/src/cohere-adapter.ts +262 -0
- package/src/error-mapper.ts +187 -0
- package/src/gemini-adapter.ts +246 -0
- package/src/groq-adapter.ts +234 -0
- package/src/host-bridge-adapter.ts +189 -0
- package/src/index.ts +11 -0
- package/src/mistral-adapter.ts +262 -0
- package/src/openai-adapter.ts +240 -0
- package/src/pricing.ts +77 -0
- package/src/rest-adapter.ts +181 -0
- package/src/strategy.ts +166 -0
- package/src/utils.ts +18 -0
- package/tsconfig.json +18 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import { spawn } from "node:child_process";
|
|
2
|
+
import { IModelProvider, ModelResponse, ProviderCallOptions } from "@rax-flow/core";
|
|
3
|
+
import { parseJsonObjectFromText } from "./utils.js";
|
|
4
|
+
|
|
5
|
+
/** Signature of an in-process bridge installed on `globalThis.__RAX_HOST_BRIDGE__`. */
export type HostBridgeFn = (payload: HostBridgeRequest) => Promise<HostBridgeResponse>;

/**
 * Request envelope handed to the host bridge: passed directly to an
 * in-process bridge function, or written as JSON to a bridge command's stdin.
 */
export interface HostBridgeRequest {
  /** Protocol version; the adapter always sends 1. */
  version: 1;
  /** Which IModelProvider method produced this request. */
  action: "callModel" | "callStructured";
  prompt: string;
  /** Expected output schema; only sent for `callStructured`. */
  schema?: object;
  options?: ProviderCallOptions;
}

/** Response envelope returned by the host bridge. */
export interface HostBridgeResponse {
  /** `false` signals a host-side failure, described by `error`. */
  ok: boolean;
  /** Model output: text, or an already-parsed object for structured calls. */
  output?: unknown;
  /** Model that served the request, if the host reports one. */
  model?: string;
  /** Host-measured latency; when omitted the adapter measures its own. */
  latencyMs?: number;
  error?: string;
}

/** Construction options for HostBridgeAdapter. */
export interface HostBridgeAdapterOptions {
  /** Model name reported when the host does not supply one (default "host-managed"). */
  model?: string;
  /**
   * Bridge selection policy (default "auto"). With "bridge-only", having no
   * bridge available is an error instead of yielding a mock response.
   */
  mode?: "auto" | "bridge-only" | "mock";
  /** Command line of an external bridge process, spawned without a shell. */
  command?: string;
  /** Timeout in milliseconds for the bridge command (default 20000). */
  timeoutMs?: number;
}
|
|
29
|
+
|
|
30
|
+
/**
 * Split a shell-like command line into an executable and its arguments.
 *
 * Tokens are separated by whitespace. A double-quoted segment ("...") may
 * contain whitespace and may be glued to unquoted text (e.g. --flag="a b").
 * Double quotes are removed from every token, including the executable, so a
 * quoted path such as "/opt/my tool/bin" reaches spawn() without quotes.
 * No other shell syntax (single quotes, escapes, variables, globs) is
 * interpreted.
 *
 * @throws Error("invalid_bridge_command") when the input yields no usable executable.
 */
function parseShellCommand(input: string): { command: string; args: string[] } {
  const tokens = (input.match(/(?:[^\s"]+|"[^"]*")+/g) ?? []).map((token) => token.replace(/"/g, ""));
  const [command, ...args] = tokens;
  if (!command) {
    throw new Error("invalid_bridge_command");
  }
  return { command, args };
}
|
|
40
|
+
|
|
41
|
+
/**
 * Interpret a bridge command's stdout.
 *
 * Stdout that parses as a JSON object carrying a boolean `ok` is taken as a
 * HostBridgeResponse. Anything else (non-JSON text, `null`, numbers, arrays,
 * objects without `ok`) is wrapped as a successful plain-text response.
 */
function parseBridgeStdout(stdout: string): HostBridgeResponse {
  try {
    const parsed: unknown = JSON.parse(stdout);
    if (parsed !== null && typeof parsed === "object" && typeof (parsed as { ok?: unknown }).ok === "boolean") {
      return parsed as HostBridgeResponse;
    }
  } catch {
    // Not JSON: fall through and treat stdout as plain text.
  }
  return { ok: true, output: stdout.trim(), model: "host-bridge-command" };
}

/**
 * Run the bridge command once: write `payload` as JSON to its stdin and build
 * the response from its stdout (see parseBridgeStdout).
 *
 * @throws Error("host_bridge_timeout") if the process is still running after
 *         `timeoutMs`; it is killed with SIGKILL.
 * @throws Error("host_bridge_command_failed:<code>:<stderr>") on a non-zero exit.
 * @throws the spawn error if the executable cannot be started.
 */
async function runBridgeCommand(commandSpec: string, payload: HostBridgeRequest, timeoutMs: number): Promise<HostBridgeResponse> {
  const { command, args } = parseShellCommand(commandSpec);
  return await new Promise<HostBridgeResponse>((resolve, reject) => {
    const child = spawn(command, args, { stdio: ["pipe", "pipe", "pipe"] });
    let stdout = "";
    let stderr = "";
    let settled = false;

    const timer = setTimeout(() => {
      child.kill("SIGKILL");
      settle(() => reject(new Error("host_bridge_timeout")));
    }, timeoutMs);

    /** Settle the promise at most once and always release the timer. */
    function settle(action: () => void): void {
      if (settled) return;
      settled = true;
      clearTimeout(timer);
      action();
    }

    // Let the streams decode UTF-8: String(buffer) per chunk corrupts
    // multi-byte characters that straddle chunk boundaries.
    child.stdout.setEncoding("utf8");
    child.stderr.setEncoding("utf8");
    child.stdout.on("data", (chunk: string) => {
      stdout += chunk;
    });
    child.stderr.on("data", (chunk: string) => {
      stderr += chunk;
    });

    child.on("error", (error) => settle(() => reject(error)));

    child.on("close", (code) => {
      settle(() => {
        if (code !== 0) {
          reject(new Error(`host_bridge_command_failed:${code}:${stderr}`));
          return;
        }
        resolve(parseBridgeStdout(stdout));
      });
    });

    // A child that exits without reading stdin makes this write fail (EPIPE).
    // Without a listener that stream error is thrown as an uncaught exception;
    // the child's outcome is still reported through "close"/"error" above.
    child.stdin.on("error", () => {});
    child.stdin.end(JSON.stringify(payload));
  });
}
|
|
85
|
+
|
|
86
|
+
/**
 * Look up an in-process bridge installed by the host on
 * `globalThis.__RAX_HOST_BRIDGE__`.
 *
 * @returns the installed function, or `undefined` when the global is absent
 *          or holds anything other than a function.
 */
function getGlobalBridge(): HostBridgeFn | undefined {
  const installed: unknown = Reflect.get(globalThis, "__RAX_HOST_BRIDGE__");
  return typeof installed === "function" ? (installed as HostBridgeFn) : undefined;
}
|
|
93
|
+
|
|
94
|
+
/**
 * IModelProvider that delegates model calls to the embedding host.
 *
 * For every call the adapter uses, in order:
 *   1. in "mock" mode, a local mock response (no bridge is contacted);
 *   2. the function installed on `globalThis.__RAX_HOST_BRIDGE__`;
 *   3. the configured `command`, spawned once per call;
 *   4. otherwise "bridge-only" throws `host_bridge_missing` and "auto"
 *      returns a local mock response.
 */
export class HostBridgeAdapter implements IModelProvider {
  private readonly model: string;
  private readonly mode: "auto" | "bridge-only" | "mock";
  private readonly command?: string;
  private readonly timeoutMs: number;

  /**
   * @param options.model     model name reported when the host omits one (default "host-managed")
   * @param options.mode      bridge selection policy (default "auto")
   * @param options.command   bridge command line, spawned without a shell
   * @param options.timeoutMs timeout for the bridge command in ms (default 20000)
   */
  constructor(options: HostBridgeAdapterOptions = {}) {
    this.model = options.model ?? "host-managed";
    this.mode = options.mode ?? "auto";
    this.command = options.command;
    this.timeoutMs = options.timeoutMs ?? 20000;
  }

  /** Route a request to the first available bridge, or answer with the mock. */
  private async callBridge(payload: HostBridgeRequest): Promise<HostBridgeResponse> {
    // "mock" must not reach a real host; previously it behaved exactly like "auto".
    if (this.mode === "mock") {
      return this.mockResponse(payload);
    }

    const bridgeFn = getGlobalBridge();
    if (bridgeFn) {
      return HostBridgeAdapter.ensureResponse(await bridgeFn(payload));
    }

    if (this.command) {
      return HostBridgeAdapter.ensureResponse(await runBridgeCommand(this.command, payload, this.timeoutMs));
    }

    if (this.mode === "bridge-only") {
      throw new Error("host_bridge_missing");
    }

    return this.mockResponse(payload);
  }

  /** Canned response used in "mock" mode and when "auto" finds no bridge. */
  private mockResponse(payload: HostBridgeRequest): HostBridgeResponse {
    return {
      ok: true,
      model: this.model,
      latencyMs: 1,
      output:
        payload.action === "callStructured"
          ? {
              agent: "HostBridgeMock",
              success: true,
              confidence: 0.7,
              risks: ["latency"],
              logs: ["host bridge unavailable, using local mock response"],
              data: {
                summary: "Mock structured output from host adapter",
                nextAction: "configure_host_bridge"
              }
            }
          : `host-mock:${payload.prompt.slice(0, 240)}`
    };
  }

  /** Reject bridge replies that are null or primitives instead of a response object. */
  private static ensureResponse(value: unknown): HostBridgeResponse {
    if (value === null || typeof value !== "object") {
      throw new Error("host_bridge_invalid_response");
    }
    return value as HostBridgeResponse;
  }

  /**
   * Render bridge output as text: strings verbatim, missing output as "",
   * anything else as JSON. (Previously missing output became the two-character
   * string `""` via JSON.stringify("").)
   */
  private static outputToText(output: unknown): string {
    if (output === undefined || output === null) return "";
    return typeof output === "string" ? output : JSON.stringify(output);
  }

  /**
   * Send a free-form prompt through the bridge.
   *
   * @throws Error("host_call_failed:<error>") when the bridge reports `ok: false`.
   */
  async callModel(prompt: string, options?: ProviderCallOptions): Promise<ModelResponse<string>> {
    const started = Date.now();
    const res = await this.callBridge({ version: 1, action: "callModel", prompt, options });
    if (!res.ok) {
      throw new Error(`host_call_failed:${res.error ?? "unknown"}`);
    }

    return {
      provider: "host",
      model: res.model ?? options?.model ?? this.model,
      latencyMs: res.latencyMs ?? Date.now() - started,
      output: HostBridgeAdapter.outputToText(res.output)
    };
  }

  /**
   * Request structured output through the bridge. Object output is used as-is;
   * string output is parsed as a JSON object.
   *
   * @throws Error("host_structured_failed:<error>") when the bridge reports `ok: false`.
   * @throws Error("host_structured_parse_failed") when no object can be obtained.
   */
  async callStructured<T>(prompt: string, schema: object, options?: ProviderCallOptions): Promise<ModelResponse<T>> {
    const started = Date.now();
    const res = await this.callBridge({ version: 1, action: "callStructured", prompt, schema, options });
    if (!res.ok) {
      throw new Error(`host_structured_failed:${res.error ?? "unknown"}`);
    }

    let output: T | null = null;
    if (res.output && typeof res.output === "object") {
      output = res.output as T;
    } else if (typeof res.output === "string") {
      output = parseJsonObjectFromText<T>(res.output);
    }

    if (!output) {
      throw new Error("host_structured_parse_failed");
    }

    return {
      provider: "host",
      model: res.model ?? options?.model ?? this.model,
      latencyMs: res.latencyMs ?? Date.now() - started,
      output
    };
  }

  /**
   * Report availability. Returns false only in "bridge-only" mode when neither
   * a global bridge function nor a command is configured; the command itself
   * is not probed.
   */
  async healthCheck(): Promise<boolean> {
    if (this.mode === "mock") return true;
    if (getGlobalBridge()) return true;
    if (this.command) return true;
    return this.mode === "auto";
  }
}
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
export * from "./error-mapper.js";
|
|
2
|
+
export * from "./openai-adapter.js";
|
|
3
|
+
export * from "./claude-adapter.js";
|
|
4
|
+
export * from "./gemini-adapter.js";
|
|
5
|
+
export * from "./groq-adapter.js";
|
|
6
|
+
export * from "./mistral-adapter.js";
|
|
7
|
+
export * from "./cohere-adapter.js";
|
|
8
|
+
export * from "./rest-adapter.js";
|
|
9
|
+
export * from "./host-bridge-adapter.js";
|
|
10
|
+
export * from "./strategy.js";
|
|
11
|
+
export * from "./utils.js";
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file mistral-adapter.ts
|
|
3
|
+
* Mistral AI bridge adapter.
|
|
4
|
+
*
|
|
5
|
+
* API reference: https://docs.mistral.ai/api/
|
|
6
|
+
*
|
|
7
|
+
* • callModel → POST /v1/chat/completions (standard)
|
|
8
|
+
* • callStructured → same endpoint with `response_format: { type: "json_object" }`
|
|
9
|
+
* Models listed in JSON_MODE_MODELS get native JSON mode
* plus a system-prompt schema hint; for any other model no
* response_format is sent and the schema instruction is
* appended to the user message instead.
|
|
12
|
+
*
|
|
13
|
+
* Supported models: mistral-large-latest, mistral-small-latest,
|
|
14
|
+
* open-mixtral-8x22b, codestral-latest, etc.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
import { IModelProvider, ModelResponse, ProviderCallOptions } from "@rax-flow/core";
|
|
18
|
+
import { parseJsonObjectFromText } from "./utils.js";
|
|
19
|
+
import {
|
|
20
|
+
RaxProviderError,
|
|
21
|
+
mapHttpError,
|
|
22
|
+
mapNetworkError,
|
|
23
|
+
mapParseError,
|
|
24
|
+
} from "./error-mapper.js";
|
|
25
|
+
import { calculateCost } from "./pricing.js";
|
|
26
|
+
|
|
27
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
28
|
+
// Wire shapes
|
|
29
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
30
|
+
|
|
31
|
+
/** One chat message in a Mistral chat-completions request. */
interface MistralMessage {
  content: string;
  role: "system" | "user" | "assistant";
}

/** Request body for POST /v1/chat/completions. */
interface MistralChatRequest {
  model: string;
  messages: MistralMessage[];
  /** Sampling temperature. */
  temperature?: number;
  /** Upper bound on generated tokens. */
  max_tokens?: number;
  /** `json_object` turns on Mistral's native JSON mode. */
  response_format?: { type: "text" | "json_object" };
  /** Mistral's safety-prompt injection flag. */
  safe_prompt?: boolean;
}

/**
 * Subset of the chat-completions response this adapter reads. For non-2xx
 * statuses the body carries `message` / `detail` instead.
 */
interface MistralChatResponse {
  id?: string;
  choices?: {
    message?: { content?: string };
    finish_reason?: string;
  }[];
  usage?: {
    prompt_tokens?: number;
    completion_tokens?: number;
    total_tokens?: number;
  };
  /** Error description (non-2xx responses). */
  message?: string;
  /** Error detail (non-2xx responses). */
  detail?: string;
}
|
|
56
|
+
|
|
57
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
58
|
+
// Adapter
|
|
59
|
+
// ─────────────────────────────────────────────────────────────────────────────
|
|
60
|
+
|
|
61
|
+
const VENDOR = "mistral";
const DEFAULT_MODEL = "mistral-large-latest";
const BASE_URL = "https://api.mistral.ai";

/** Models known to support native json_object response_format. */
const JSON_MODE_MODELS = new Set([
  "mistral-large-latest",
  "mistral-small-latest",
  "codestral-latest",
  "open-mixtral-8x22b",
  "open-mistral-nemo",
]);

/** Construction options for MistralAdapter. */
export interface MistralAdapterOptions {
  /** Mistral API key, sent as a Bearer token. */
  apiKey: string;
  /** API origin without the `/v1` path (default "https://api.mistral.ai"). */
  baseUrl?: string;
  /** Model used when a call does not name one (default "mistral-large-latest"). */
  defaultModel?: string;
  /** Request timeout in milliseconds (default 25000). */
  timeoutMs?: number;
  /**
   * Value sent as `safe_prompt` on `callModel` requests (default: false).
   * `callStructured` always sends `safe_prompt: false`, regardless of this option.
   */
  safePrompt?: boolean;
}
|
|
82
|
+
|
|
83
|
+
/**
 * IModelProvider for Mistral's chat-completions API.
 *
 * Errors are normalised through error-mapper:
 * - transport failures and timeouts go through mapNetworkError;
 * - non-2xx statuses go through mapHttpError;
 * - unparseable structured output goes through mapParseError;
 * - `finish_reason: "content_filter"` becomes a RaxProviderError with code "content_filtered".
 */
export class MistralAdapter implements IModelProvider {
  private readonly apiKey: string;
  private readonly baseUrl: string;
  private readonly defaultModel: string;
  private readonly timeoutMs: number;
  private readonly safePrompt: boolean;

  constructor(options: MistralAdapterOptions) {
    this.apiKey = options.apiKey;
    this.baseUrl = options.baseUrl ?? BASE_URL;
    this.defaultModel = options.defaultModel ?? DEFAULT_MODEL;
    this.timeoutMs = options.timeoutMs ?? 25_000;
    this.safePrompt = options.safePrompt ?? false;
  }

  // ── private helpers ────────────────────────────────────────────────────────

  /**
   * Whether to request `response_format: json_object` for `model`.
   *
   * True for an ID listed in JSON_MODE_MODELS, or for a pinned version of
   * the same family: the ID with "-latest" removed, followed by "-<suffix>"
   * (e.g. "mistral-large-2411", "open-mistral-nemo-2407"). Exact-only
   * matching silently dropped JSON mode for pinned production models.
   */
  private supportsJsonMode(model: string): boolean {
    if (JSON_MODE_MODELS.has(model)) return true;
    for (const known of JSON_MODE_MODELS) {
      const family = known.replace(/-latest$/, "");
      if (model.startsWith(`${family}-`)) return true;
    }
    return false;
  }

  /**
   * POST a chat-completions request and return the decoded body.
   *
   * The timeout covers the whole exchange, including reading the body.
   *
   * @throws RaxProviderError mapped from network, HTTP or content-filter failures.
   */
  private async post(
    body: MistralChatRequest,
    context: "callModel" | "callStructured"
  ): Promise<MistralChatResponse> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), this.timeoutMs);

    let res: Response;
    try {
      res = await fetch(`${this.baseUrl}/v1/chat/completions`, {
        method: "POST",
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
          "Content-Type": "application/json",
          Accept: "application/json",
        },
        body: JSON.stringify(body),
        signal: controller.signal,
      });
    } catch (err) {
      clearTimeout(timer);
      throw mapNetworkError(VENDOR, err);
    }

    // fetch() resolves once headers arrive, so the timer stays armed while the
    // body is read. json() rejections become {} (so clearTimeout is always reached).
    const raw = await res.json().catch(() => ({}));
    clearTimeout(timer);
    // The body read was cut short by the timeout: report it like a fetch abort.
    if (controller.signal.aborted) {
      throw mapNetworkError(VENDOR, controller.signal.reason);
    }

    if (!res.ok) {
      throw mapHttpError(VENDOR, res.status, raw, context);
    }

    const payload = raw as MistralChatResponse;

    // Mistral can signal a content-policy stop
    if (payload.choices?.[0]?.finish_reason === "content_filter") {
      throw new RaxProviderError(
        VENDOR,
        "content_filtered",
        "Mistral content policy rejection",
        { raw: payload }
      );
    }

    return payload;
  }

  /**
   * Text of the first choice. Message content may also arrive as a list of
   * chunks (NOTE(review): per Mistral's API schema — confirm). In that case the
   * `text` fields are concatenated and other chunks (e.g. reasoning) are skipped.
   */
  private extractText(payload: MistralChatResponse): string {
    const content: unknown = payload.choices?.[0]?.message?.content;
    if (typeof content === "string") return content;
    if (Array.isArray(content)) {
      return content
        .map((chunk: unknown) =>
          chunk !== null && typeof chunk === "object" ? (chunk as { text?: unknown }).text : undefined
        )
        .filter((text): text is string => typeof text === "string")
        .join("");
    }
    return "";
  }

  /** Convert Mistral's snake_case usage block to the provider-neutral shape. */
  private toUsage(payload: MistralChatResponse) {
    if (!payload.usage) return undefined;
    return {
      promptTokens: payload.usage.prompt_tokens ?? 0,
      completionTokens: payload.usage.completion_tokens ?? 0,
      totalTokens: payload.usage.total_tokens ?? 0,
    };
  }

  // ── IModelProvider ─────────────────────────────────────────────────────────

  /** Single-turn completion; `safe_prompt` follows the adapter's `safePrompt` option. */
  async callModel(
    prompt: string,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<string>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    const payload = await this.post(
      {
        model,
        temperature: options?.temperature ?? 0.2,
        max_tokens: options?.maxTokens ?? 1200,
        safe_prompt: this.safePrompt,
        messages: [{ role: "user", content: prompt }],
      },
      "callModel"
    );

    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output: this.extractText(payload),
      raw: payload,
    };
  }

  /**
   * Ask for a JSON object matching `schema` and parse it.
   *
   * JSON-mode models get a system-prompt schema hint plus
   * `response_format: json_object`. Other models get the instruction appended
   * to the user message. `safe_prompt` is always false here.
   *
   * @throws mapParseError result when no JSON object can be parsed from the reply.
   */
  async callStructured<T>(
    prompt: string,
    schema: object,
    options?: ProviderCallOptions
  ): Promise<ModelResponse<T>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;
    const jsonMode = this.supportsJsonMode(model);

    const messages: MistralMessage[] = jsonMode
      ? [
          {
            role: "system",
            content: `Respond with ONLY a valid JSON object that matches this schema:\n${JSON.stringify(schema, null, 2)}`,
          },
          { role: "user", content: prompt },
        ]
      : [
          {
            role: "user",
            content: `${prompt}\n\nRespond ONLY with a JSON object matching this schema:\n${JSON.stringify(schema)}`,
          },
        ];

    const requestBody: MistralChatRequest = {
      model,
      temperature: options?.temperature ?? 0,
      max_tokens: options?.maxTokens ?? 1400,
      safe_prompt: false, // disable safe rewrites on structured tasks
      messages,
    };
    if (jsonMode) {
      requestBody.response_format = { type: "json_object" };
    }

    const payload = await this.post(requestBody, "callStructured");
    const text = this.extractText(payload);
    const parsed = parseJsonObjectFromText<T>(text);
    if (!parsed) {
      throw mapParseError(VENDOR, "callStructured", text);
    }

    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output: parsed,
      raw: payload,
    };
  }

  /** Configuration check only: true when an API key and base URL are set (no network call). */
  async healthCheck(): Promise<boolean> {
    return Boolean(this.apiKey && this.baseUrl);
  }
}
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file openai-adapter.ts
|
|
3
|
+
* OpenAI bridge adapter.
|
|
4
|
+
*
|
|
5
|
+
* API reference: https://platform.openai.com/docs/api-reference/chat
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { IModelProvider, IEmbeddingProvider, ModelResponse, ProviderCallOptions, EmbeddingResponse } from "@rax-flow/core";
|
|
9
|
+
import { asString, parseJsonObjectFromText } from "./utils.js";
|
|
10
|
+
import { calculateCost } from "./pricing.js";
|
|
11
|
+
import {
|
|
12
|
+
RaxProviderError,
|
|
13
|
+
mapHttpError,
|
|
14
|
+
mapNetworkError,
|
|
15
|
+
mapParseError,
|
|
16
|
+
} from "./error-mapper.js";
|
|
17
|
+
|
|
18
|
+
/** One element of array-form message content (e.g. `{ type: "text", text }`). */
type ContentPart = {
  type: string;
  text?: string;
};

/** Chat message sent to /chat/completions. */
interface OpenAIMessage {
  content: string | ContentPart[];
  role: "system" | "user" | "assistant";
}

/** Request body for POST /chat/completions. */
interface OpenAIChatRequest {
  model: string;
  messages: OpenAIMessage[];
  temperature?: number;
  max_tokens?: number;
  /**
   * Output format: plain text, JSON mode (`json_object`), or Structured
   * Outputs (`json_schema`) constrained by the supplied schema.
   */
  response_format?:
    | { type: "text" }
    | { type: "json_object" }
    | { type: "json_schema"; json_schema: { name: string; strict: boolean; schema: object } };
}

/** Subset of the chat-completions response (or error envelope) this adapter reads. */
interface OpenAIChatResponse {
  id?: string;
  choices?: {
    message?: { content?: string | ContentPart[] };
    finish_reason?: string;
  }[];
  usage?: {
    prompt_tokens?: number;
    completion_tokens?: number;
    total_tokens?: number;
  };
  error?: { message?: string; type?: string; code?: string };
}
|
|
48
|
+
|
|
49
|
+
const VENDOR = "openai";
const DEFAULT_MODEL = "gpt-4.1-mini";
const DEFAULT_BASE_URL = "https://api.openai.com/v1";

/**
 * Model families sent `response_format: { type: "json_schema" }` (Structured
 * Outputs); the adapter matches an entry exactly or as a prefix followed by "-".
 *
 * "gpt-4-turbo" and "o1-mini" are deliberately absent: OpenAI's Structured
 * Outputs documentation does not list them, and a json_schema request to them
 * is rejected. They fall back to `json_object` mode instead.
 */
const STRICT_JSON_SCHEMA_MODELS = new Set([
  "gpt-4o",
  "gpt-4o-mini",
  "gpt-4.1",
  "gpt-4.1-mini",
  "gpt-4.1-nano",
  "o1",
  "o3-mini",
]);

/** Construction options for OpenAIAdapter. */
export interface OpenAIAdapterOptions {
  /** OpenAI API key, sent as a Bearer token. */
  apiKey: string;
  /** API base including the version path (default "https://api.openai.com/v1"). */
  baseUrl?: string;
  /** Model used when a call does not name one (default "gpt-4.1-mini"). */
  defaultModel?: string;
  /** Timeout for chat requests in milliseconds (default 30000). */
  timeoutMs?: number;
}
|
|
71
|
+
|
|
72
|
+
/**
 * IModelProvider / IEmbeddingProvider for OpenAI's chat-completions and
 * embeddings endpoints.
 *
 * Errors are normalised through error-mapper:
 * - transport failures and timeouts go through mapNetworkError;
 * - non-2xx statuses go through mapHttpError;
 * - unparseable output goes through mapParseError;
 * - content-policy stops and refusals become RaxProviderError "content_filtered".
 */
export class OpenAIAdapter implements IModelProvider, IEmbeddingProvider {
  private readonly apiKey: string;
  private readonly baseUrl: string;
  private readonly defaultModel: string;
  private readonly timeoutMs: number;

  constructor(options: OpenAIAdapterOptions);
  /** @deprecated Pass an options object instead of positional args. */
  constructor(apiKey: string, baseUrl?: string);
  constructor(
    optionsOrApiKey: OpenAIAdapterOptions | string,
    legacyBaseUrl?: string
  ) {
    if (typeof optionsOrApiKey === "string") {
      this.apiKey = optionsOrApiKey;
      this.baseUrl = legacyBaseUrl ?? DEFAULT_BASE_URL;
      this.defaultModel = DEFAULT_MODEL;
      this.timeoutMs = 30_000;
    } else {
      this.apiKey = optionsOrApiKey.apiKey;
      this.baseUrl = optionsOrApiKey.baseUrl ?? DEFAULT_BASE_URL;
      this.defaultModel = optionsOrApiKey.defaultModel ?? DEFAULT_MODEL;
      this.timeoutMs = optionsOrApiKey.timeoutMs ?? 30_000;
    }
  }

  /** True if `model` equals a STRICT_JSON_SCHEMA_MODELS entry or is "<entry>-<suffix>". */
  private supportsStrictJsonSchema(model: string): boolean {
    for (const known of STRICT_JSON_SCHEMA_MODELS) {
      if (model === known || model.startsWith(`${known}-`)) return true;
    }
    return false;
  }

  /**
   * POST JSON to `${baseUrl}${path}` with the adapter's timeout and return the
   * decoded body.
   *
   * The timeout covers reading the body too; fetch() resolves on headers.
   * A non-JSON body decodes to {}.
   */
  private async request(path: string, body: unknown): Promise<{ res: Response; raw: unknown }> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), this.timeoutMs);

    let res: Response;
    try {
      res = await fetch(`${this.baseUrl}${path}`, {
        method: "POST",
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify(body),
        signal: controller.signal,
      });
    } catch (err) {
      clearTimeout(timer);
      throw mapNetworkError(VENDOR, err);
    }

    // json() rejections become {}, so clearTimeout is always reached.
    const raw: unknown = await res.json().catch(() => ({}));
    clearTimeout(timer);
    // The body read was cut short by the timeout: report it like a fetch abort.
    if (controller.signal.aborted) {
      throw mapNetworkError(VENDOR, controller.signal.reason);
    }
    return { res, raw };
  }

  /**
   * POST a chat-completions request and return the decoded body.
   *
   * @throws RaxProviderError for network/HTTP failures, error envelopes,
   *         `finish_reason: "content_filter"` and model refusals.
   */
  private async post(
    body: OpenAIChatRequest,
    context: "callModel" | "callStructured"
  ): Promise<OpenAIChatResponse> {
    const { res, raw } = await this.request("/chat/completions", body);
    if (!res.ok) throw mapHttpError(VENDOR, res.status, raw, context);

    const payload = raw as OpenAIChatResponse;
    if (payload.error) {
      const errMsg = payload.error.message ?? "unknown openai error";
      const isRefusal = payload.error.type === "invalid_request_error" && /content_policy|moderat|filter/i.test(errMsg);
      throw new RaxProviderError(VENDOR, isRefusal ? "content_filtered" : "invalid_request", errMsg, { raw: payload });
    }

    const choice = payload.choices?.[0];
    if (choice?.finish_reason === "content_filter") {
      throw new RaxProviderError(VENDOR, "content_filtered", "OpenAI content policy rejection", { raw: payload });
    }

    // A safety refusal arrives as `message.refusal` (with no content) rather than as an error.
    const message: { content?: unknown; refusal?: unknown } | undefined = choice?.message;
    if (typeof message?.refusal === "string" && message.refusal.length > 0) {
      throw new RaxProviderError(VENDOR, "content_filtered", message.refusal, { raw: payload });
    }
    return payload;
  }

  /** Text of the first choice; array content contributes only its text parts, newline-joined. */
  private extractText(payload: OpenAIChatResponse): string {
    const content = payload.choices?.[0]?.message?.content;
    if (Array.isArray(content)) {
      return content
        .map((part) => part.text)
        .filter((text): text is string => typeof text === "string")
        .join("\n");
    }
    return content ?? "";
  }

  /** Convert OpenAI's snake_case usage block to the provider-neutral shape. */
  private toUsage(payload: OpenAIChatResponse) {
    if (!payload.usage) return undefined;
    return {
      promptTokens: payload.usage.prompt_tokens ?? 0,
      completionTokens: payload.usage.completion_tokens ?? 0,
      totalTokens: payload.usage.total_tokens ?? 0,
    };
  }

  /** Single-turn chat completion. */
  async callModel(prompt: string, options?: ProviderCallOptions): Promise<ModelResponse<string>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;
    const payload = await this.post({
      model,
      temperature: options?.temperature ?? 0.2,
      max_tokens: options?.maxTokens ?? 1200,
      messages: [{ role: "user", content: prompt }],
    }, "callModel");

    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output: this.extractText(payload),
      raw: payload,
    };
  }

  /**
   * Ask for a JSON object matching `schema`. Structured-Outputs models get
   * `json_schema` with strict validation; other models get `json_object` mode
   * plus a system-prompt schema hint.
   *
   * @throws mapParseError result when the reply is empty or not a JSON object.
   */
  async callStructured<T>(prompt: string, schema: object, options?: ProviderCallOptions): Promise<ModelResponse<T>> {
    const started = Date.now();
    const model = options?.model ?? this.defaultModel;

    const body: OpenAIChatRequest = {
      model,
      temperature: options?.temperature ?? 0,
      max_tokens: options?.maxTokens ?? 1400,
      messages: [{ role: "user", content: prompt }],
    };

    if (this.supportsStrictJsonSchema(model)) {
      body.response_format = { type: "json_schema", json_schema: { name: "rax_flow_output", strict: true, schema } };
    } else {
      body.response_format = { type: "json_object" };
      body.messages = [
        { role: "system", content: `Respond ONLY with a valid JSON object matching this schema:\n${JSON.stringify(schema, null, 2)}` },
        { role: "user", content: prompt },
      ];
    }

    const payload = await this.post(body, "callStructured");
    // No "{}" default: an empty reply (refusal, truncation) used to parse as an
    // empty object and be returned as a successful result.
    const text = asString(this.extractText(payload));
    if (!text.trim()) throw mapParseError(VENDOR, "callStructured", text);
    const parsed = parseJsonObjectFromText<T>(text);
    if (!parsed) throw mapParseError(VENDOR, "callStructured", text);

    const usage = this.toUsage(payload);
    return {
      provider: VENDOR,
      model,
      latencyMs: Date.now() - started,
      costUsd: calculateCost(model, usage),
      usage,
      output: parsed,
      raw: payload,
    };
  }

  /**
   * Embed `texts` via the embeddings endpoint, with the adapter's timeout.
   *
   * @param texts inputs to embed; an empty list returns no vectors without calling the API.
   * @param model embedding model (default "text-embedding-3-small").
   */
  async embed(texts: string[], model = "text-embedding-3-small"): Promise<EmbeddingResponse> {
    if (texts.length === 0) {
      return { vectors: [], model, usage: { totalTokens: 0 } };
    }

    // NOTE(review): "embed" is cast as in the original call site; presumably the
    // error-mapper context parameter is typed for chat calls only — verify.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const context = "embed" as any;

    const { res, raw } = await this.request("/embeddings", { input: texts, model });
    if (!res.ok) throw mapHttpError(VENDOR, res.status, raw, context);

    const payload = raw as {
      data?: Array<{ embedding: number[] }>;
      model?: string;
      usage?: { total_tokens?: number };
    };
    if (!Array.isArray(payload.data)) {
      throw mapParseError(VENDOR, context, JSON.stringify(raw));
    }

    return {
      vectors: payload.data.map((d) => d.embedding),
      model: payload.model ?? model,
      usage: { totalTokens: payload.usage?.total_tokens ?? 0 },
    };
  }

  /** Configuration check only: true when an API key and base URL are set (no network call). */
  async healthCheck(): Promise<boolean> {
    return Boolean(this.apiKey && this.baseUrl);
  }
}
|