@hebo-ai/gateway 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -2
- package/dist/config.js +2 -0
- package/dist/endpoints/chat-completions/converters.d.ts +5 -1
- package/dist/endpoints/chat-completions/converters.js +61 -12
- package/dist/endpoints/chat-completions/schema.d.ts +54 -9
- package/dist/endpoints/chat-completions/schema.js +20 -13
- package/dist/models/anthropic/middleware.js +14 -13
- package/dist/providers/bedrock/middleware.d.ts +2 -1
- package/dist/providers/bedrock/middleware.js +29 -9
- package/dist/telemetry/ai-sdk.d.ts +2 -0
- package/dist/telemetry/ai-sdk.js +31 -0
- package/package.json +1 -1
- package/src/config.ts +3 -0
- package/src/endpoints/chat-completions/converters.test.ts +111 -0
- package/src/endpoints/chat-completions/converters.ts +71 -13
- package/src/endpoints/chat-completions/handler.ts +10 -3
- package/src/endpoints/chat-completions/schema.ts +22 -14
- package/src/endpoints/embeddings/handler.ts +5 -3
- package/src/middleware/debug.ts +37 -0
- package/src/middleware/matcher.ts +4 -0
- package/src/models/anthropic/middleware.test.ts +5 -1
- package/src/models/anthropic/middleware.ts +17 -13
- package/src/providers/bedrock/middleware.test.ts +118 -8
- package/src/providers/bedrock/middleware.ts +34 -9
- package/src/telemetry/ai-sdk.ts +46 -0
|
@@ -56,6 +56,7 @@ export type TextCallOptions = {
|
|
|
56
56
|
messages: ModelMessage[];
|
|
57
57
|
tools?: ToolSet;
|
|
58
58
|
toolChoice?: ToolChoice<ToolSet>;
|
|
59
|
+
activeTools?: Array<keyof ToolSet>;
|
|
59
60
|
output?: Output.Output;
|
|
60
61
|
temperature?: number;
|
|
61
62
|
maxOutputTokens?: number;
|
|
@@ -90,10 +91,13 @@ export function convertToTextCallOptions(params: ChatCompletionsInputs): TextCal
|
|
|
90
91
|
|
|
91
92
|
Object.assign(rest, parseReasoningOptions(reasoning_effort, reasoning));
|
|
92
93
|
|
|
94
|
+
const { toolChoice, activeTools } = convertToToolChoiceOptions(tool_choice);
|
|
95
|
+
|
|
93
96
|
return {
|
|
94
97
|
messages: convertToModelMessages(messages),
|
|
95
98
|
tools: convertToToolSet(tools),
|
|
96
|
-
toolChoice
|
|
99
|
+
toolChoice,
|
|
100
|
+
activeTools,
|
|
97
101
|
output: convertToOutput(response_format),
|
|
98
102
|
temperature,
|
|
99
103
|
maxOutputTokens: max_completion_tokens ?? max_tokens,
|
|
@@ -321,30 +325,43 @@ export const convertToToolSet = (tools: ChatCompletionsTool[] | undefined): Tool
|
|
|
321
325
|
toolSet[t.function.name] = tool({
|
|
322
326
|
description: t.function.description,
|
|
323
327
|
inputSchema: jsonSchema(t.function.parameters),
|
|
328
|
+
strict: t.function.strict,
|
|
324
329
|
});
|
|
325
330
|
}
|
|
326
331
|
return toolSet;
|
|
327
332
|
};
|
|
328
333
|
|
|
329
|
-
export const
|
|
334
|
+
export const convertToToolChoiceOptions = (
|
|
330
335
|
toolChoice: ChatCompletionsToolChoice | undefined,
|
|
331
|
-
):
|
|
336
|
+
): {
|
|
337
|
+
toolChoice?: ToolChoice<ToolSet>;
|
|
338
|
+
activeTools?: Array<keyof ToolSet>;
|
|
339
|
+
} => {
|
|
332
340
|
if (!toolChoice) {
|
|
333
|
-
return
|
|
341
|
+
return {};
|
|
334
342
|
}
|
|
335
343
|
|
|
336
344
|
if (toolChoice === "none" || toolChoice === "auto" || toolChoice === "required") {
|
|
337
|
-
return toolChoice;
|
|
345
|
+
return { toolChoice };
|
|
338
346
|
}
|
|
339
347
|
|
|
340
348
|
// FUTURE: "validated" is currently Google-specific and not supported by the AI SDK; until it is, we temporarily map it to "auto". See https://docs.cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview
|
|
341
349
|
if (toolChoice === "validated") {
|
|
342
|
-
return "auto";
|
|
350
|
+
return { toolChoice: "auto" };
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
if (toolChoice.type === "allowed_tools") {
|
|
354
|
+
return {
|
|
355
|
+
toolChoice: toolChoice.allowed_tools.mode,
|
|
356
|
+
activeTools: toolChoice.allowed_tools.tools.map((toolRef) => toolRef.function.name),
|
|
357
|
+
};
|
|
343
358
|
}
|
|
344
359
|
|
|
345
360
|
return {
|
|
346
|
-
|
|
347
|
-
|
|
361
|
+
toolChoice: {
|
|
362
|
+
type: "tool",
|
|
363
|
+
toolName: toolChoice.function.name,
|
|
364
|
+
},
|
|
348
365
|
};
|
|
349
366
|
};
|
|
350
367
|
|
|
@@ -617,9 +634,11 @@ export const toChatCompletionsAssistantMessage = (
|
|
|
617
634
|
if (part.type === "text") {
|
|
618
635
|
if (message.content === null) {
|
|
619
636
|
message.content = part.text;
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
637
|
+
} else {
|
|
638
|
+
message.content += part.text;
|
|
639
|
+
}
|
|
640
|
+
if (part.providerMetadata) {
|
|
641
|
+
message.extra_content = part.providerMetadata;
|
|
623
642
|
}
|
|
624
643
|
} else if (part.type === "reasoning") {
|
|
625
644
|
reasoningDetails.push(
|
|
@@ -646,6 +665,11 @@ export const toChatCompletionsAssistantMessage = (
|
|
|
646
665
|
message.reasoning_details = reasoningDetails;
|
|
647
666
|
}
|
|
648
667
|
|
|
668
|
+
if (!message.content && !message.tool_calls) {
|
|
669
|
+
// some models return just reasoning without tool calls or content
|
|
670
|
+
message.content = "";
|
|
671
|
+
}
|
|
672
|
+
|
|
649
673
|
return message;
|
|
650
674
|
};
|
|
651
675
|
|
|
@@ -722,8 +746,8 @@ export function toChatCompletionsToolCall(
|
|
|
722
746
|
id,
|
|
723
747
|
type: "function",
|
|
724
748
|
function: {
|
|
725
|
-
name,
|
|
726
|
-
arguments: typeof args === "string" ? args : JSON.stringify(args),
|
|
749
|
+
name: normalizeToolName(name),
|
|
750
|
+
arguments: typeof args === "string" ? args : JSON.stringify(stripEmptyKeys(args)),
|
|
727
751
|
},
|
|
728
752
|
};
|
|
729
753
|
|
|
@@ -734,6 +758,40 @@ export function toChatCompletionsToolCall(
|
|
|
734
758
|
return out;
|
|
735
759
|
}
|
|
736
760
|
|
|
761
|
+
/**
 * Sanitizes a tool name produced by a model so it only contains characters
 * from the set `A-Z a-z 0-9 _ - .` and is at most 128 characters long.
 *
 * Some models hallucinate invalid characters in tool names. Every UTF-16 code
 * unit outside the allowed set is replaced by a single underscore; input past
 * the 128th output character is dropped.
 *
 * Implemented with a manual character-code scan instead of a regex for
 * performance.
 *
 * @see https://modelcontextprotocol.io/specification/draft/server/tools#tool-names
 * @param name - Raw tool name as returned by the model.
 * @returns The sanitized name (empty when `name` is empty).
 */
function normalizeToolName(name: string): string {
  const MAX_LENGTH = 128;

  let out = "";
  for (let i = 0; i < name.length && out.length < MAX_LENGTH; i++) {
    // eslint-disable-next-line unicorn/prefer-code-point
    const c = name.charCodeAt(i);

    const isAllowed =
      (c >= 0x30 && c <= 0x39) || // 0-9
      (c >= 0x41 && c <= 0x5a) || // A-Z
      (c >= 0x61 && c <= 0x7a) || // a-z
      c === 0x5f || // _
      c === 0x2d || // -
      c === 0x2e; // .

    out += isAllowed ? name[i] : "_";
  }
  return out;
}
|
|
787
|
+
|
|
788
|
+
/**
 * Removes the top-level empty-string key (`""`) from a tool-call argument
 * object. Some models hallucinate such an empty parameter.
 *
 * The input is never mutated: when an own `""` key is present, a shallow copy
 * without that key is returned, so the caller's original object (for example
 * the one still referenced by the upstream result) stays intact. Nested
 * objects are not inspected.
 *
 * @param obj - Parsed tool-call arguments, or any other value.
 * @returns `obj` itself when it is not a plain object-like value (nullish,
 *   primitive, array) or has no own `""` key; otherwise a shallow copy
 *   without the `""` key.
 */
function stripEmptyKeys(obj: unknown): unknown {
  if (!obj || typeof obj !== "object" || Array.isArray(obj)) return obj;

  // Nothing to strip: hand back the same reference and skip the copy.
  if (!Object.prototype.hasOwnProperty.call(obj, "")) return obj;

  const copy: Record<string, unknown> = { ...(obj as Record<string, unknown>) };
  delete copy[""];
  return copy;
}
|
|
794
|
+
|
|
737
795
|
export const toChatCompletionsFinishReason = (
|
|
738
796
|
finishReason: FinishReason,
|
|
739
797
|
): ChatCompletionsFinishReason => {
|
|
@@ -36,7 +36,7 @@ import {
|
|
|
36
36
|
getChatRequestAttributes,
|
|
37
37
|
getChatResponseAttributes,
|
|
38
38
|
} from "./otel";
|
|
39
|
-
import { ChatCompletionsBodySchema } from "./schema";
|
|
39
|
+
import { ChatCompletionsBodySchema, type ChatCompletionsBody } from "./schema";
|
|
40
40
|
|
|
41
41
|
export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
42
42
|
const hooks = config.hooks;
|
|
@@ -57,6 +57,7 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
|
57
57
|
} catch {
|
|
58
58
|
throw new GatewayError("Invalid JSON", 400);
|
|
59
59
|
}
|
|
60
|
+
logger.trace({ requestId: ctx.requestId, body: ctx.body }, "[chat] ChatCompletionsBody");
|
|
60
61
|
addSpanEvent("hebo.request.deserialized");
|
|
61
62
|
|
|
62
63
|
const parsed = ChatCompletionsBodySchema.safeParse(ctx.body);
|
|
@@ -68,7 +69,8 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
|
68
69
|
addSpanEvent("hebo.request.parsed");
|
|
69
70
|
|
|
70
71
|
if (hooks?.before) {
|
|
71
|
-
ctx.body =
|
|
72
|
+
ctx.body =
|
|
73
|
+
((await hooks.before(ctx as BeforeHookContext)) as ChatCompletionsBody) ?? ctx.body;
|
|
72
74
|
addSpanEvent("hebo.hooks.before.completed");
|
|
73
75
|
}
|
|
74
76
|
|
|
@@ -110,7 +112,7 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
|
110
112
|
"[chat] AI SDK options",
|
|
111
113
|
);
|
|
112
114
|
addSpanEvent("hebo.options.prepared");
|
|
113
|
-
setSpanAttributes(getChatRequestAttributes(
|
|
115
|
+
setSpanAttributes(getChatRequestAttributes(ctx.body, genAiSignalLevel));
|
|
114
116
|
|
|
115
117
|
// Build middleware chain (model -> forward params -> provider).
|
|
116
118
|
const languageModelWithMiddleware = wrapLanguageModel({
|
|
@@ -138,6 +140,10 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
|
138
140
|
res as unknown as GenerateTextResult<ToolSet, Output.Output>,
|
|
139
141
|
ctx.resolvedModelId!,
|
|
140
142
|
);
|
|
143
|
+
logger.trace(
|
|
144
|
+
{ requestId: ctx.requestId, result: streamResult },
|
|
145
|
+
"[chat] ChatCompletions",
|
|
146
|
+
);
|
|
141
147
|
addSpanEvent("hebo.result.transformed");
|
|
142
148
|
|
|
143
149
|
const genAiResponseAttrs = getChatResponseAttributes(streamResult, genAiSignalLevel);
|
|
@@ -180,6 +186,7 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
|
|
|
180
186
|
|
|
181
187
|
// Transform result.
|
|
182
188
|
ctx.result = toChatCompletions(result, ctx.resolvedModelId);
|
|
189
|
+
logger.trace({ requestId: ctx.requestId, result: ctx.result }, "[chat] ChatCompletions");
|
|
183
190
|
addSpanEvent("hebo.result.transformed");
|
|
184
191
|
|
|
185
192
|
const genAiResponseAttrs = getChatResponseAttributes(ctx.result, genAiSignalLevel);
|
|
@@ -135,20 +135,33 @@ export const ChatCompletionsToolSchema = z.object({
|
|
|
135
135
|
name: z.string(),
|
|
136
136
|
description: z.string().optional(),
|
|
137
137
|
parameters: z.record(z.string(), z.unknown()),
|
|
138
|
-
|
|
138
|
+
strict: z.boolean().optional(),
|
|
139
139
|
}),
|
|
140
140
|
});
|
|
141
141
|
export type ChatCompletionsTool = z.infer<typeof ChatCompletionsToolSchema>;
|
|
142
142
|
|
|
143
|
+
/** Tool-choice variant that references a single function tool by name. */
const ChatCompletionsNamedFunctionToolChoiceSchema = z.object({
  type: z.literal("function"),
  function: z.object({ name: z.string() }),
});
|
|
149
|
+
|
|
150
|
+
/**
 * Tool-choice variant that restricts the model to a non-empty list of named
 * function tools, with `mode` being either "auto" or "required".
 */
const ChatCompletionsAllowedFunctionToolChoiceSchema = z.object({
  type: z.literal("allowed_tools"),
  allowed_tools: z.object({
    mode: z.enum(["auto", "required"]),
    tools: z.array(ChatCompletionsNamedFunctionToolChoiceSchema).nonempty(),
  }),
});
|
|
157
|
+
|
|
143
158
|
export const ChatCompletionsToolChoiceSchema = z.union([
|
|
144
159
|
z.enum(["none", "auto", "required", "validated"]),
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
}),
|
|
151
|
-
}),
|
|
160
|
+
z.discriminatedUnion("type", [
|
|
161
|
+
ChatCompletionsNamedFunctionToolChoiceSchema,
|
|
162
|
+
ChatCompletionsAllowedFunctionToolChoiceSchema,
|
|
163
|
+
]),
|
|
164
|
+
// FUTURE: Missing CustomTool
|
|
152
165
|
]);
|
|
153
166
|
export type ChatCompletionsToolChoice = z.infer<typeof ChatCompletionsToolChoiceSchema>;
|
|
154
167
|
|
|
@@ -193,12 +206,7 @@ export type ChatCompletionsResponseFormat = z.infer<typeof ChatCompletionsRespon
|
|
|
193
206
|
|
|
194
207
|
const ChatCompletionsInputsSchema = z.object({
|
|
195
208
|
messages: z.array(ChatCompletionsMessageSchema),
|
|
196
|
-
tools: z
|
|
197
|
-
.array(
|
|
198
|
-
// FUTURE: Missing CustomTool
|
|
199
|
-
ChatCompletionsToolSchema,
|
|
200
|
-
)
|
|
201
|
-
.optional(),
|
|
209
|
+
tools: z.array(ChatCompletionsToolSchema).optional(),
|
|
202
210
|
tool_choice: ChatCompletionsToolChoiceSchema.optional(),
|
|
203
211
|
temperature: z.number().min(0).max(2).optional(),
|
|
204
212
|
max_tokens: z.int().nonnegative().optional(),
|
|
@@ -29,7 +29,7 @@ import {
|
|
|
29
29
|
getEmbeddingsRequestAttributes,
|
|
30
30
|
getEmbeddingsResponseAttributes,
|
|
31
31
|
} from "./otel";
|
|
32
|
-
import { EmbeddingsBodySchema } from "./schema";
|
|
32
|
+
import { EmbeddingsBodySchema, type EmbeddingsBody } from "./schema";
|
|
33
33
|
|
|
34
34
|
export const embeddings = (config: GatewayConfig): Endpoint => {
|
|
35
35
|
const hooks = config.hooks;
|
|
@@ -50,6 +50,7 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
|
|
|
50
50
|
} catch {
|
|
51
51
|
throw new GatewayError("Invalid JSON", 400);
|
|
52
52
|
}
|
|
53
|
+
// Trace the raw request body; tagged "[embeddings]" like the other log lines of this handler.
logger.trace({ requestId: ctx.requestId, body: ctx.body }, "[embeddings] EmbeddingsBody");
|
|
53
54
|
addSpanEvent("hebo.request.deserialized");
|
|
54
55
|
|
|
55
56
|
const parsed = EmbeddingsBodySchema.safeParse(ctx.body);
|
|
@@ -61,7 +62,7 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
|
|
|
61
62
|
addSpanEvent("hebo.request.parsed");
|
|
62
63
|
|
|
63
64
|
if (hooks?.before) {
|
|
64
|
-
ctx.body = (await hooks.before(ctx as BeforeHookContext)) ?? ctx.body;
|
|
65
|
+
ctx.body = ((await hooks.before(ctx as BeforeHookContext)) as EmbeddingsBody) ?? ctx.body;
|
|
65
66
|
addSpanEvent("hebo.hooks.before.completed");
|
|
66
67
|
}
|
|
67
68
|
|
|
@@ -100,7 +101,7 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
|
|
|
100
101
|
"[embeddings] AI SDK options",
|
|
101
102
|
);
|
|
102
103
|
addSpanEvent("hebo.options.prepared");
|
|
103
|
-
setSpanAttributes(getEmbeddingsRequestAttributes(
|
|
104
|
+
setSpanAttributes(getEmbeddingsRequestAttributes(ctx.body, genAiSignalLevel));
|
|
104
105
|
|
|
105
106
|
// Build middleware chain (model -> forward params -> provider).
|
|
106
107
|
const embeddingModelWithMiddleware = wrapEmbeddingModel({
|
|
@@ -121,6 +122,7 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
|
|
|
121
122
|
|
|
122
123
|
// Transform result.
|
|
123
124
|
ctx.result = toEmbeddings(result, ctx.modelId);
|
|
125
|
+
// Trace the transformed result; tagged "[embeddings]" like the other log lines of this handler.
logger.trace({ requestId: ctx.requestId, result: ctx.result }, "[embeddings] Embeddings");
|
|
124
126
|
addSpanEvent("hebo.result.transformed");
|
|
125
127
|
const genAiResponseAttrs = getEmbeddingsResponseAttributes(ctx.result, genAiSignalLevel);
|
|
126
128
|
recordTokenUsage(genAiResponseAttrs, genAiGeneralAttrs, genAiSignalLevel);
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import type { EmbeddingModelMiddleware, LanguageModelMiddleware } from "ai";
|
|
2
|
+
|
|
3
|
+
import { logger } from "../logger";
|
|
4
|
+
|
|
5
|
+
/**
 * Language-model middleware that trace-logs the params it receives together
 * with the model and provider ids, then passes the params through unchanged.
 */
export const debugLanguageFinalParamsMiddleware: LanguageModelMiddleware = {
  specificationVersion: "v3",
  // eslint-disable-next-line require-await
  async transformParams({ params, model }) {
    const { modelId, provider: providerId } = model;
    logger.trace({ kind: "text", modelId, providerId, params }, "[middleware] final params");
    return params;
  },
};
|
|
21
|
+
|
|
22
|
+
/**
 * Embedding-model middleware that trace-logs the params it receives together
 * with the model and provider ids, then passes the params through unchanged.
 */
export const debugEmbeddingFinalParamsMiddleware: EmbeddingModelMiddleware = {
  specificationVersion: "v3",
  // eslint-disable-next-line require-await
  async transformParams({ params, model }) {
    const { modelId, provider: providerId } = model;
    logger.trace({ kind: "embedding", modelId, providerId, params }, "[middleware] final params");
    return params;
  },
};
|
|
@@ -6,6 +6,7 @@ import type { ProviderId } from "../providers/types";
|
|
|
6
6
|
import { logger } from "../logger";
|
|
7
7
|
import { addSpanEvent } from "../telemetry/span";
|
|
8
8
|
import { forwardParamsEmbeddingMiddleware, forwardParamsMiddleware } from "./common";
|
|
9
|
+
import { debugEmbeddingFinalParamsMiddleware, debugLanguageFinalParamsMiddleware } from "./debug";
|
|
9
10
|
|
|
10
11
|
type MiddlewareEntries = {
|
|
11
12
|
language?: LanguageModelMiddleware[];
|
|
@@ -110,6 +111,9 @@ class ModelMiddlewareMatcher {
|
|
|
110
111
|
if (providerId) {
|
|
111
112
|
out.push(...this.collect(this.provider.match(providerId), kind));
|
|
112
113
|
}
|
|
114
|
+
out.push(
|
|
115
|
+
kind === "text" ? debugLanguageFinalParamsMiddleware : debugEmbeddingFinalParamsMiddleware,
|
|
116
|
+
);
|
|
113
117
|
|
|
114
118
|
if (this.cache.size >= ModelMiddlewareMatcher.MAX_CACHE) {
|
|
115
119
|
let n = Math.ceil(ModelMiddlewareMatcher.MAX_CACHE * 0.2);
|
|
@@ -125,7 +125,7 @@ test("claudeReasoningMiddleware > should transform reasoning object to thinking
|
|
|
125
125
|
anthropic: {
|
|
126
126
|
thinking: {
|
|
127
127
|
type: "enabled",
|
|
128
|
-
budgetTokens:
|
|
128
|
+
budgetTokens: 2000,
|
|
129
129
|
},
|
|
130
130
|
},
|
|
131
131
|
unknown: {},
|
|
@@ -412,6 +412,7 @@ test("claudeReasoningMiddleware > should map none effort to low for Claude Sonne
|
|
|
412
412
|
|
|
413
413
|
expect(result.providerOptions?.anthropic?.thinking).toEqual({
|
|
414
414
|
type: "enabled",
|
|
415
|
+
budgetTokens: 1024,
|
|
415
416
|
});
|
|
416
417
|
expect(result.providerOptions?.anthropic?.effort).toBe("low");
|
|
417
418
|
});
|
|
@@ -518,6 +519,7 @@ test("claudeReasoningMiddleware > should map max effort to high for Claude Sonne
|
|
|
518
519
|
|
|
519
520
|
expect(result.providerOptions?.anthropic?.thinking).toEqual({
|
|
520
521
|
type: "enabled",
|
|
522
|
+
budgetTokens: 60800,
|
|
521
523
|
});
|
|
522
524
|
expect(result.providerOptions?.anthropic?.effort).toBe("high");
|
|
523
525
|
});
|
|
@@ -543,6 +545,7 @@ test("claudeReasoningMiddleware > should map xhigh effort to high for Claude Son
|
|
|
543
545
|
|
|
544
546
|
expect(result.providerOptions?.anthropic?.thinking).toEqual({
|
|
545
547
|
type: "enabled",
|
|
548
|
+
budgetTokens: 60800,
|
|
546
549
|
});
|
|
547
550
|
expect(result.providerOptions?.anthropic?.effort).toBe("high");
|
|
548
551
|
});
|
|
@@ -590,6 +593,7 @@ test("claudeReasoningMiddleware > should map xhigh effort for Claude Opus 4.5 wi
|
|
|
590
593
|
|
|
591
594
|
expect(result.providerOptions?.anthropic?.thinking).toEqual({
|
|
592
595
|
type: "enabled",
|
|
596
|
+
budgetTokens: 60800,
|
|
593
597
|
});
|
|
594
598
|
expect(result.providerOptions?.anthropic?.effort).toBe("high");
|
|
595
599
|
});
|
|
@@ -16,11 +16,12 @@ const isClaude = (family: "opus" | "sonnet" | "haiku", version: string) => {
|
|
|
16
16
|
modelId.includes(`claude-${family}-${dashed}`);
|
|
17
17
|
};
|
|
18
18
|
|
|
19
|
+
/** True when the model id contains both the substring "claude-" and the substring "-4". */
const isClaude4 = (modelId: string) =>
  ["claude-", "-4"].every((fragment) => modelId.includes(fragment));
|
|
20
|
+
|
|
19
21
|
const isOpus46 = isClaude("opus", "4.6");
|
|
20
22
|
const isOpus45 = isClaude("opus", "4.5");
|
|
21
23
|
const isOpus4 = isClaude("opus", "4");
|
|
22
24
|
const isSonnet46 = isClaude("sonnet", "4.6");
|
|
23
|
-
const isSonnet45 = isClaude("sonnet", "4.5");
|
|
24
25
|
|
|
25
26
|
export function mapClaudeReasoningEffort(effort: ChatCompletionsReasoningEffort, modelId: string) {
|
|
26
27
|
if (isOpus46(modelId)) {
|
|
@@ -60,7 +61,10 @@ function getMaxOutputTokens(modelId: string): number {
|
|
|
60
61
|
return 64_000;
|
|
61
62
|
}
|
|
62
63
|
|
|
64
|
+
// Documentation:
|
|
63
65
|
// https://platform.claude.com/docs/en/build-with-claude/effort
|
|
66
|
+
// https://platform.claude.com/docs/en/build-with-claude/extended-thinking
|
|
67
|
+
// https://platform.claude.com/docs/en/build-with-claude/adaptive-thinking
|
|
64
68
|
export const claudeReasoningMiddleware: LanguageModelMiddleware = {
|
|
65
69
|
specificationVersion: "v3",
|
|
66
70
|
// eslint-disable-next-line require-await
|
|
@@ -79,30 +83,30 @@ export const claudeReasoningMiddleware: LanguageModelMiddleware = {
|
|
|
79
83
|
if (!reasoning.enabled) {
|
|
80
84
|
target["thinking"] = { type: "disabled" };
|
|
81
85
|
} else if (reasoning.effort) {
|
|
86
|
+
if (isClaude4(modelId)) {
|
|
87
|
+
target["effort"] = mapClaudeReasoningEffort(reasoning.effort, modelId);
|
|
88
|
+
}
|
|
89
|
+
|
|
82
90
|
if (isOpus46(modelId)) {
|
|
83
91
|
target["thinking"] = clampedMaxTokens
|
|
84
92
|
? { type: "adaptive", budgetTokens: clampedMaxTokens }
|
|
85
93
|
: { type: "adaptive" };
|
|
86
|
-
target["effort"] = mapClaudeReasoningEffort(reasoning.effort, modelId);
|
|
87
94
|
} else if (isSonnet46(modelId)) {
|
|
88
95
|
target["thinking"] = clampedMaxTokens
|
|
89
96
|
? { type: "enabled", budgetTokens: clampedMaxTokens }
|
|
90
97
|
: { type: "adaptive" };
|
|
91
|
-
target["effort"] = mapClaudeReasoningEffort(reasoning.effort, modelId);
|
|
92
|
-
} else if (isOpus45(modelId) || isSonnet45(modelId)) {
|
|
93
|
-
target["thinking"] = { type: "enabled" };
|
|
94
|
-
if (clampedMaxTokens) target["thinking"]["budgetTokens"] = clampedMaxTokens;
|
|
95
|
-
target["effort"] = mapClaudeReasoningEffort(reasoning.effort, modelId);
|
|
96
98
|
} else {
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
99
|
+
target["thinking"] = { type: "enabled" };
|
|
100
|
+
if (clampedMaxTokens) {
|
|
101
|
+
target["thinking"]["budgetTokens"] = clampedMaxTokens;
|
|
102
|
+
} else {
|
|
103
|
+
// FUTURE: warn that reasoning.max_tokens was computed
|
|
104
|
+
target["thinking"]["budgetTokens"] = calculateReasoningBudgetFromEffort(
|
|
101
105
|
reasoning.effort,
|
|
102
106
|
params.maxOutputTokens ?? getMaxOutputTokens(modelId),
|
|
103
107
|
1024,
|
|
104
|
-
)
|
|
105
|
-
}
|
|
108
|
+
);
|
|
109
|
+
}
|
|
106
110
|
}
|
|
107
111
|
} else if (clampedMaxTokens) {
|
|
108
112
|
target["thinking"] = {
|
|
@@ -2,19 +2,73 @@ import { MockLanguageModelV3 } from "ai/test";
|
|
|
2
2
|
import { expect, test } from "bun:test";
|
|
3
3
|
|
|
4
4
|
import { modelMiddlewareMatcher } from "../../middleware/matcher";
|
|
5
|
-
import {
|
|
5
|
+
import { bedrockClaudeReasoningMiddleware, bedrockGptReasoningMiddleware } from "./middleware";
|
|
6
6
|
|
|
7
|
-
test("
|
|
7
|
+
// Resolving text middleware for a GPT model on Bedrock must include the GPT reasoning middleware.
test("bedrock middlewares > matching provider resolves GPT middleware", () => {
  const resolved = modelMiddlewareMatcher.resolve({
    kind: "text",
    modelId: "openai/gpt-oss-20b",
    providerId: "amazon-bedrock",
  });

  expect(resolved).toContain(bedrockGptReasoningMiddleware);
});
|
|
16
|
+
|
|
17
|
+
test("bedrock middlewares > matching provider resolves Claude middleware", () => {
|
|
8
18
|
const middleware = modelMiddlewareMatcher.resolve({
|
|
9
19
|
kind: "text",
|
|
10
20
|
modelId: "anthropic/claude-opus-4.6",
|
|
11
21
|
providerId: "amazon-bedrock",
|
|
12
22
|
});
|
|
13
23
|
|
|
14
|
-
expect(middleware).toContain(
|
|
24
|
+
expect(middleware).toContain(bedrockClaudeReasoningMiddleware);
|
|
15
25
|
});
|
|
16
26
|
|
|
17
|
-
test("
|
|
27
|
+
// For a GPT model, `reasoningEffort` is expected to be rewritten into `reasoningConfig.maxReasoningEffort`.
test("bedrockGptReasoningMiddleware > should map reasoningEffort into reasoningConfig", async () => {
  const input = {
    prompt: [],
    providerOptions: { bedrock: { reasoningEffort: "high" } },
  };
  const model = new MockLanguageModelV3({ modelId: "openai/gpt-oss-20b" });

  const transformed = await bedrockGptReasoningMiddleware.transformParams!({
    type: "generate",
    params: input,
    model,
  });

  expect(transformed.providerOptions?.bedrock).toEqual({
    reasoningConfig: { maxReasoningEffort: "high" },
  });
});
|
|
49
|
+
|
|
50
|
+
// For a non-GPT model the bedrock provider options are expected to pass through untouched.
test("bedrockGptReasoningMiddleware > should skip non-gpt models", async () => {
  const input = {
    prompt: [],
    providerOptions: { bedrock: { reasoningEffort: "medium" } },
  };
  const model = new MockLanguageModelV3({ modelId: "anthropic/claude-opus-4.6" });

  const transformed = await bedrockGptReasoningMiddleware.transformParams!({
    type: "generate",
    params: input,
    model,
  });

  expect(transformed.providerOptions?.bedrock).toEqual({ reasoningEffort: "medium" });
});
|
|
70
|
+
|
|
71
|
+
test("bedrockClaudeReasoningMiddleware > should map thinking/effort into reasoningConfig", async () => {
|
|
18
72
|
const params = {
|
|
19
73
|
prompt: [],
|
|
20
74
|
providerOptions: {
|
|
@@ -28,10 +82,10 @@ test("bedrockAnthropicReasoningMiddleware > should map thinking/effort into reas
|
|
|
28
82
|
},
|
|
29
83
|
};
|
|
30
84
|
|
|
31
|
-
const result = await
|
|
85
|
+
const result = await bedrockClaudeReasoningMiddleware.transformParams!({
|
|
32
86
|
type: "generate",
|
|
33
87
|
params,
|
|
34
|
-
model: new MockLanguageModelV3({ modelId: "anthropic/claude-opus-4
|
|
88
|
+
model: new MockLanguageModelV3({ modelId: "anthropic/claude-opus-4-6" }),
|
|
35
89
|
});
|
|
36
90
|
|
|
37
91
|
expect(result.providerOptions?.bedrock).toEqual({
|
|
@@ -43,7 +97,7 @@ test("bedrockAnthropicReasoningMiddleware > should map thinking/effort into reas
|
|
|
43
97
|
});
|
|
44
98
|
});
|
|
45
99
|
|
|
46
|
-
test("
|
|
100
|
+
test("bedrockClaudeReasoningMiddleware > should skip non-claude models", async () => {
|
|
47
101
|
const params = {
|
|
48
102
|
prompt: [],
|
|
49
103
|
providerOptions: {
|
|
@@ -57,7 +111,7 @@ test("bedrockAnthropicReasoningMiddleware > should skip non-anthropic models", a
|
|
|
57
111
|
},
|
|
58
112
|
};
|
|
59
113
|
|
|
60
|
-
const result = await
|
|
114
|
+
const result = await bedrockClaudeReasoningMiddleware.transformParams!({
|
|
61
115
|
type: "generate",
|
|
62
116
|
params,
|
|
63
117
|
model: new MockLanguageModelV3({ modelId: "openai/gpt-oss-20b" }),
|
|
@@ -71,3 +125,59 @@ test("bedrockAnthropicReasoningMiddleware > should skip non-anthropic models", a
|
|
|
71
125
|
effort: "high",
|
|
72
126
|
});
|
|
73
127
|
});
|
|
128
|
+
|
|
129
|
+
// Claude 3.x: thinking settings move into `reasoningConfig`, and no `maxReasoningEffort` is emitted.
test("bedrockClaudeReasoningMiddleware > should not set maxReasoningEffort for Claude 3.x", async () => {
  const input = {
    prompt: [],
    providerOptions: {
      bedrock: {
        thinking: { type: "enabled", budgetTokens: 4096 },
        effort: "high",
      },
    },
  };
  const model = new MockLanguageModelV3({ modelId: "anthropic/claude-sonnet-3.7" });

  const transformed = await bedrockClaudeReasoningMiddleware.transformParams!({
    type: "generate",
    params: input,
    model,
  });

  expect(transformed.providerOptions?.bedrock).toEqual({
    reasoningConfig: { type: "enabled", budgetTokens: 4096 },
  });
});
|
|
156
|
+
|
|
157
|
+
// Claude 4.5: thinking settings move into `reasoningConfig`, and no `maxReasoningEffort` is emitted.
test("bedrockClaudeReasoningMiddleware > should not set maxReasoningEffort for Claude 4.5", async () => {
  const input = {
    prompt: [],
    providerOptions: {
      bedrock: {
        thinking: { type: "enabled", budgetTokens: 4096 },
        effort: "high",
      },
    },
  };
  const model = new MockLanguageModelV3({ modelId: "anthropic/claude-opus-4.5" });

  const transformed = await bedrockClaudeReasoningMiddleware.transformParams!({
    type: "generate",
    params: input,
    model,
  });

  expect(transformed.providerOptions?.bedrock).toEqual({
    reasoningConfig: { type: "enabled", budgetTokens: 4096 },
  });
});
|