workers-ai-provider 0.6.3 → 0.7.0

This diff shows the content of publicly released package versions as published to the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in those registries.
package/src/utils.ts CHANGED
@@ -66,6 +66,7 @@ export function createRun(config: CreateRunConfig): AiRun {
    inputs: AiModels[Name]["inputs"],
    options?: AiOptions & Record<string, StringLike>,
  ): Promise<Response | ReadableStream<Uint8Array> | AiModels[Name]["postProcessedOutputs"]> {
+   // biome-ignore lint/correctness/noUnusedVariables: they need to be destructured
    const { gateway, prefix, extraHeaders, returnRawResponse, ...passthroughOptions } =
      options || {};

@@ -78,7 +79,7 @@ export function createRun(config: CreateRunConfig): AiRun {
          continue;
        }
        urlParams.append(key, valueStr);
-     } catch (error) {
+     } catch (_error) {
        throw new Error(
          `Value for option '${key}' is not able to be coerced into a string.`,
        );
@@ -91,17 +92,17 @@ export function createRun(config: CreateRunConfig): AiRun {

    // Merge default and custom headers.
    const headers = {
-     "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
+     "Content-Type": "application/json",
    };

    const body = JSON.stringify(inputs);

    // Execute the POST request. The optional AbortSignal is applied here.
    const response = await fetch(url, {
-     method: "POST",
-     headers,
      body,
+     headers,
+     method: "POST",
    });

    // (1) If the user explicitly requests the raw Response, return it as-is.
@@ -134,42 +135,42 @@ export function prepareToolsAndToolChoice(
  const tools = mode.tools?.length ? mode.tools : undefined;

  if (tools == null) {
-   return { tools: undefined, tool_choice: undefined };
+   return { tool_choice: undefined, tools: undefined };
  }

  const mappedTools = tools.map((tool) => ({
-   type: "function",
    function: {
-     name: tool.name,
      // @ts-expect-error - description is not a property of tool
      description: tool.description,
+     name: tool.name,
      // @ts-expect-error - parameters is not a property of tool
      parameters: tool.parameters,
    },
+   type: "function",
  }));

  const toolChoice = mode.toolChoice;

  if (toolChoice == null) {
-   return { tools: mappedTools, tool_choice: undefined };
+   return { tool_choice: undefined, tools: mappedTools };
  }

  const type = toolChoice.type;

  switch (type) {
    case "auto":
-     return { tools: mappedTools, tool_choice: type };
+     return { tool_choice: type, tools: mappedTools };
    case "none":
-     return { tools: mappedTools, tool_choice: type };
+     return { tool_choice: type, tools: mappedTools };
    case "required":
-     return { tools: mappedTools, tool_choice: "any" };
+     return { tool_choice: "any", tools: mappedTools };

    // workersAI does not support tool mode directly,
    // so we filter the tools and force the tool choice through 'any'
    case "tool":
      return {
-       tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
        tool_choice: "any",
+       tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
      };
    default: {
      const exhaustiveCheck = type satisfies never;
@@ -190,12 +191,12 @@ function mergePartialToolCalls(partialCalls: any[]) {

    if (!mergedCallsByIndex[index]) {
      mergedCallsByIndex[index] = {
-       id: partialCall.id || "",
-       type: partialCall.type || "",
        function: {
-         name: partialCall.function?.name || "",
          arguments: "",
+         name: partialCall.function?.name || "",
        },
+       id: partialCall.id || "",
+       type: partialCall.type || "",
      };
    } else {
      if (partialCall.id) {
@@ -223,23 +224,23 @@ function processToolCall(toolCall: any): LanguageModelV1FunctionToolCall {
  // Check for OpenAI format tool calls first
  if (toolCall.function && toolCall.id) {
    return {
-     toolCallType: "function",
-     toolCallId: toolCall.id,
-     toolName: toolCall.function.name,
      args:
        typeof toolCall.function.arguments === "string"
          ? toolCall.function.arguments
          : JSON.stringify(toolCall.function.arguments || {}),
+     toolCallId: toolCall.id,
+     toolCallType: "function",
+     toolName: toolCall.function.name,
    };
  }
  return {
-   toolCallType: "function",
-   toolCallId: toolCall.name,
-   toolName: toolCall.name,
    args:
      typeof toolCall.arguments === "string"
        ? toolCall.arguments
        : JSON.stringify(toolCall.arguments || {}),
+   toolCallId: toolCall.name,
+   toolCallType: "function",
+   toolName: toolCall.name,
  };
}

@@ -251,6 +252,16 @@ export function processToolCalls(output: any): LanguageModelV1FunctionToolCall[]
    });
  }

+ if (
+   output?.choices?.[0]?.message?.tool_calls &&
+   Array.isArray(output.choices[0].message.tool_calls)
+ ) {
+   return output.choices[0].message.tool_calls.map((toolCall: any) => {
+     const processedToolCall = processToolCall(toolCall);
+     return processedToolCall;
+   });
+ }
+
  return [];
}

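The `processToolCalls` addition above is the one behavioral change in utils.ts: 0.7.0 now also recognizes OpenAI-style tool calls nested under `choices[0].message.tool_calls`, alongside the existing top-level `tool_calls` array. A minimal sketch of the two output shapes it normalizes (the payloads below are illustrative, not captured model output):

    // Top-level shape, already handled in 0.6.x:
    const legacyOutput = {
      tool_calls: [{ name: "getWeather", arguments: { city: "Austin" } }],
    };

    // OpenAI-compatible shape, newly handled in 0.7.0:
    const openAiOutput = {
      choices: [
        {
          message: {
            tool_calls: [
              { id: "call_0", function: { name: "getWeather", arguments: '{"city":"Austin"}' } },
            ],
          },
        },
      ],
    };

    // Both yield the same normalized list from processToolCalls(output):
    // [{ args: '{"city":"Austin"}', toolCallId: ..., toolCallType: "function", toolName: "getWeather" }]
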
package/src/workersai-embedding-model.ts CHANGED
@@ -1,4 +1,4 @@
- import { TooManyEmbeddingValuesForCallError, type EmbeddingModelV1 } from "@ai-sdk/provider";
+ import { type EmbeddingModelV1, TooManyEmbeddingValuesForCallError } from "@ai-sdk/provider";
  import type { StringLike } from "./utils";
  import type { EmbeddingModels } from "./workersai-models";

@@ -63,9 +63,9 @@ export class WorkersAIEmbeddingModel implements EmbeddingModelV1<string> {
  > {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
-       provider: this.provider,
-       modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
+       modelId: this.modelId,
+       provider: this.provider,
        values,
      });
    }
package/src/workersai-chat-language-model.ts CHANGED
@@ -5,13 +5,12 @@ import {
    UnsupportedFunctionalityError,
  } from "@ai-sdk/provider";
  import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";
- import type { WorkersAIChatSettings } from "./workersai-chat-settings";
- import type { TextGenerationModels } from "./workersai-models";
-
+ import { mapWorkersAIFinishReason } from "./map-workersai-finish-reason";
  import { mapWorkersAIUsage } from "./map-workersai-usage";
  import { getMappedStream } from "./streaming";
  import { lastMessageWasUser, prepareToolsAndToolChoice, processToolCalls } from "./utils";
- import { mapWorkersAIFinishReason } from "./map-workersai-finish-reason";
+ import type { WorkersAIChatSettings } from "./workersai-chat-settings";
+ import type { TextGenerationModels } from "./workersai-models";

  type WorkersAIChatConfig = {
    provider: string;
@@ -57,30 +56,29 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {

    if (frequencyPenalty != null) {
      warnings.push({
-       type: "unsupported-setting",
        setting: "frequencyPenalty",
+       type: "unsupported-setting",
      });
    }

    if (presencePenalty != null) {
      warnings.push({
-       type: "unsupported-setting",
        setting: "presencePenalty",
+       type: "unsupported-setting",
      });
    }

    const baseArgs = {
+     // standardized settings:
+     max_tokens: maxTokens,
      // model id:
      model: this.modelId,
+     random_seed: seed,

      // model specific settings:
      safe_prompt: this.settings.safePrompt,
-
-     // standardized settings:
-     max_tokens: maxTokens,
      temperature,
      top_p: topP,
-     random_seed: seed,
    };

    switch (type) {
@@ -96,8 +94,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
        args: {
          ...baseArgs,
          response_format: {
-           type: "json_schema",
            json_schema: mode.schema,
+           type: "json_schema",
          },
          tools: undefined,
        },
@@ -110,7 +108,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
        args: {
          ...baseArgs,
          tool_choice: "any",
-         tools: [{ type: "function", function: mode.tool }],
+         tools: [{ function: mode.tool, type: "function" }],
        },
        warnings,
      };
@@ -136,6 +134,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
  ): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
    const { args, warnings } = this.getArgs(options);

+   // biome-ignore lint/correctness/noUnusedVariables: this needs to be destructured
    const { gateway, safePrompt, ...passthroughOptions } = this.settings;

    // Extract image from messages if present
@@ -151,8 +150,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
    const output = await this.config.binding.run(
      args.model,
      {
-       messages: messages,
        max_tokens: args.max_tokens,
+       messages: messages,
        temperature: args.temperature,
        tools: args.tools,
        top_p: args.top_p,
@@ -170,14 +169,16 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
    }

    return {
+     finishReason: mapWorkersAIFinishReason(output),
+     rawCall: { rawPrompt: messages, rawSettings: args },
+     rawResponse: { body: output },
      text:
        typeof output.response === "object" && output.response !== null
          ? JSON.stringify(output.response) // ai-sdk expects a string here
          : output.response,
      toolCalls: processToolCalls(output),
-     finishReason: mapWorkersAIFinishReason(output),
-     rawCall: { rawPrompt: messages, rawSettings: args },
-     rawResponse: { body: output },
+     // @ts-ignore: Missing types
+     reasoning: output?.choices?.[0]?.message?.reasoning_content,
      usage: mapWorkersAIUsage(output),
      warnings,
    };
@@ -202,12 +203,13 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
    }

    return {
+     rawCall: { rawPrompt: messages, rawSettings: args },
      stream: new ReadableStream<LanguageModelV1StreamPart>({
        async start(controller) {
          if (response.text) {
            controller.enqueue({
-             type: "text-delta",
              textDelta: response.text,
+             type: "text-delta",
            });
          }
          if (response.toolCalls) {
@@ -218,15 +220,20 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
            });
          }
        }
+       if (response.reasoning && typeof response.reasoning === "string") {
+         controller.enqueue({
+           type: "reasoning",
+           textDelta: response.reasoning,
+         });
+       }
        controller.enqueue({
-         type: "finish",
          finishReason: mapWorkersAIFinishReason(response),
+         type: "finish",
          usage: response.usage,
        });
        controller.close();
      },
    }),
-   rawCall: { rawPrompt: messages, rawSettings: args },
    warnings,
  };
}
@@ -244,8 +251,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
    const response = await this.config.binding.run(
      args.model,
      {
-       messages: messages,
        max_tokens: args.max_tokens,
+       messages: messages,
        stream: true,
        temperature: args.temperature,
        tools: args.tools,
@@ -264,8 +271,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
    }

    return {
-     stream: getMappedStream(new Response(response)),
      rawCall: { rawPrompt: messages, rawSettings: args },
+     stream: getMappedStream(new Response(response)),
      warnings,
    };
  }
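The `reasoning` field added to the `doGenerate` return value and the new `reasoning` stream part in `doStream` surface a model's `reasoning_content` through the AI SDK. A minimal sketch of consuming it, assuming a Worker with an `AI` binding, the `ai` package, and the `Ai` type from @cloudflare/workers-types; the model ID is illustrative, and only reasoning-capable models populate the field:

    import { generateText } from "ai";
    import { createWorkersAI } from "workers-ai-provider";

    export default {
      async fetch(_request: Request, env: { AI: Ai }): Promise<Response> {
        const workersai = createWorkersAI({ binding: env.AI });

        const { text, reasoning } = await generateText({
          // Illustrative model ID; any model that returns reasoning_content works.
          model: workersai("@cf/deepseek-ai/deepseek-r1-distill-qwen-32b"),
          prompt: "What is 17 * 23?",
        });

        // `reasoning` carries the chain of thought; `text` carries the final answer.
        return Response.json({ reasoning, text });
      },
    };
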
package/src/workersai-chat-settings.ts CHANGED
@@ -9,7 +9,6 @@ export type WorkersAIChatSettings = {

  /**
   * Optionally set Cloudflare AI Gateway options.
-  * @deprecated
   */
  gateway?: GatewayOptions;
} & {
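With the `@deprecated` tag removed, the per-model `gateway` setting is again a supported way to route requests through Cloudflare AI Gateway. A short sketch (the gateway id and model ID are illustrative):

    import { createWorkersAI } from "workers-ai-provider";

    // Inside a Worker handler where env.AI is the Workers AI binding.
    const workersai = createWorkersAI({ binding: env.AI });
    const model = workersai("@cf/meta/llama-3.1-8b-instruct", {
      gateway: { id: "my-gateway" },
    });
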
package/src/workersai-error.ts CHANGED
@@ -2,11 +2,11 @@ import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
  import { z } from "zod";

  const workersAIErrorDataSchema = z.object({
-   object: z.literal("error"),
+   code: z.string().nullable(),
    message: z.string(),
-   type: z.string(),
+   object: z.literal("error"),
    param: z.string().nullable(),
-   code: z.string().nullable(),
+   type: z.string(),
  });

  export const workersAIFailedResponseHandler = createJsonErrorResponseHandler({
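This schema change is a pure key reordering, and zod object validation ignores key order, so the accepted error shape is unchanged. For reference, a payload that satisfies `workersAIErrorDataSchema` (field values are illustrative):

    const sampleError = {
      code: null,
      message: "No such model",
      object: "error",
      param: null,
      type: "invalid_request_error",
    };
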
package/src/workersai-image-model.ts CHANGED
@@ -36,9 +36,9 @@ export class WorkersAIImageModel implements ImageModelV1 {

    if (aspectRatio != null) {
      warnings.push({
-       type: "unsupported-setting",
-       setting: "aspectRatio",
        details: "This model does not support aspect ratio. Use `size` instead.",
+       setting: "aspectRatio",
+       type: "unsupported-setting",
      });
    }

@@ -46,10 +46,10 @@ export class WorkersAIImageModel implements ImageModelV1 {
    const outputStream: ReadableStream<Uint8Array> = await this.config.binding.run(
      this.modelId,
      {
+       height,
        prompt,
        seed,
        width,
-       height,
      },
    );

@@ -65,12 +65,12 @@ export class WorkersAIImageModel implements ImageModelV1 {

    return {
      images,
-     warnings,
      response: {
-       timestamp: new Date(),
-       modelId: this.modelId,
        headers: {},
+       modelId: this.modelId,
+       timestamp: new Date(),
      },
+     warnings,
    };
  }
}
@@ -79,8 +79,8 @@ function getDimensionsFromSizeString(size: string | undefined) {
  const [width, height] = size?.split("x") ?? [undefined, undefined];

  return {
-   width: parseInteger(width),
    height: parseInteger(height),
+   width: parseInteger(width),
  };
}
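As the image-model hunks show, output dimensions come from the `size` option parsed as "<width>x<height>", while `aspectRatio` only yields an unsupported-setting warning. A minimal usage sketch under the same Worker assumptions as above (model ID is illustrative):

    import { experimental_generateImage as generateImage } from "ai";
    import { createWorkersAI } from "workers-ai-provider";

    // Inside a Worker handler where env.AI is the Workers AI binding.
    const workersai = createWorkersAI({ binding: env.AI });

    const { image } = await generateImage({
      model: workersai.image("@cf/black-forest-labs/flux-1-schnell"),
      prompt: "A lighthouse at dusk",
      size: "512x512", // parsed by getDimensionsFromSizeString into { height, width }
    });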