workers-ai-provider 0.4.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -1
- package/dist/index.d.ts +94 -8
- package/dist/index.js +365 -167
- package/dist/index.js.map +1 -1
- package/package.json +3 -3
- package/src/autorag-chat-language-model.ts +172 -0
- package/src/autorag-chat-settings.ts +14 -0
- package/src/convert-to-workersai-chat-messages.ts +8 -18
- package/src/index.ts +74 -1
- package/src/streaming.ts +37 -0
- package/src/utils.ts +93 -1
- package/src/workers-ai-embedding-model.ts +87 -0
- package/src/workersai-chat-language-model.ts +5 -128
- package/src/workersai-models.ts +5 -0
package/src/streaming.ts
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
import { events } from "fetch-event-stream";
|
2
|
+
|
3
|
+
import type { LanguageModelV1StreamPart } from "@ai-sdk/provider";
|
4
|
+
import { mapWorkersAIUsage } from "./map-workersai-usage";
|
5
|
+
|
6
|
+
export function getMappedStream(response: Response) {
|
7
|
+
const chunkEvent = events(response);
|
8
|
+
let usage = { promptTokens: 0, completionTokens: 0 };
|
9
|
+
|
10
|
+
return new ReadableStream<LanguageModelV1StreamPart>({
|
11
|
+
async start(controller) {
|
12
|
+
for await (const event of chunkEvent) {
|
13
|
+
if (!event.data) {
|
14
|
+
continue;
|
15
|
+
}
|
16
|
+
if (event.data === "[DONE]") {
|
17
|
+
break;
|
18
|
+
}
|
19
|
+
const chunk = JSON.parse(event.data);
|
20
|
+
if (chunk.usage) {
|
21
|
+
usage = mapWorkersAIUsage(chunk);
|
22
|
+
}
|
23
|
+
chunk.response?.length &&
|
24
|
+
controller.enqueue({
|
25
|
+
type: "text-delta",
|
26
|
+
textDelta: chunk.response,
|
27
|
+
});
|
28
|
+
}
|
29
|
+
controller.enqueue({
|
30
|
+
type: "finish",
|
31
|
+
finishReason: "stop",
|
32
|
+
usage: usage,
|
33
|
+
});
|
34
|
+
controller.close();
|
35
|
+
},
|
36
|
+
});
|
37
|
+
}
|
package/src/utils.ts
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
import type { LanguageModelV1 } from "@ai-sdk/provider";
|
2
|
+
|
1
3
|
/**
|
2
4
|
* General AI run interface with overloads to handle distinct return types.
|
3
5
|
*
|
@@ -83,7 +85,9 @@ export function createRun(config: CreateRunConfig): AiRun {
|
|
83
85
|
}
|
84
86
|
}
|
85
87
|
|
86
|
-
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${
|
88
|
+
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${
|
89
|
+
urlParams ? `?${urlParams}` : ""
|
90
|
+
}`;
|
87
91
|
|
88
92
|
// Merge default and custom headers.
|
89
93
|
const headers = {
|
@@ -120,3 +124,91 @@ export function createRun(config: CreateRunConfig): AiRun {
|
|
120
124
|
return data.result;
|
121
125
|
};
|
122
126
|
}
|
127
|
+
|
128
|
+
export function prepareToolsAndToolChoice(
|
129
|
+
mode: Parameters<LanguageModelV1["doGenerate"]>[0]["mode"] & {
|
130
|
+
type: "regular";
|
131
|
+
},
|
132
|
+
) {
|
133
|
+
// when the tools array is empty, change it to undefined to prevent errors:
|
134
|
+
const tools = mode.tools?.length ? mode.tools : undefined;
|
135
|
+
|
136
|
+
if (tools == null) {
|
137
|
+
return { tools: undefined, tool_choice: undefined };
|
138
|
+
}
|
139
|
+
|
140
|
+
const mappedTools = tools.map((tool) => ({
|
141
|
+
type: "function",
|
142
|
+
function: {
|
143
|
+
name: tool.name,
|
144
|
+
// @ts-expect-error - description is not a property of tool
|
145
|
+
description: tool.description,
|
146
|
+
// @ts-expect-error - parameters is not a property of tool
|
147
|
+
parameters: tool.parameters,
|
148
|
+
},
|
149
|
+
}));
|
150
|
+
|
151
|
+
const toolChoice = mode.toolChoice;
|
152
|
+
|
153
|
+
if (toolChoice == null) {
|
154
|
+
return { tools: mappedTools, tool_choice: undefined };
|
155
|
+
}
|
156
|
+
|
157
|
+
const type = toolChoice.type;
|
158
|
+
|
159
|
+
switch (type) {
|
160
|
+
case "auto":
|
161
|
+
return { tools: mappedTools, tool_choice: type };
|
162
|
+
case "none":
|
163
|
+
return { tools: mappedTools, tool_choice: type };
|
164
|
+
case "required":
|
165
|
+
return { tools: mappedTools, tool_choice: "any" };
|
166
|
+
|
167
|
+
// workersAI does not support tool mode directly,
|
168
|
+
// so we filter the tools and force the tool choice through 'any'
|
169
|
+
case "tool":
|
170
|
+
return {
|
171
|
+
tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
|
172
|
+
tool_choice: "any",
|
173
|
+
};
|
174
|
+
default: {
|
175
|
+
const exhaustiveCheck = type satisfies never;
|
176
|
+
throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
|
177
|
+
}
|
178
|
+
}
|
179
|
+
}
|
180
|
+
|
181
|
+
export function lastMessageWasUser<T extends { role: string }>(messages: T[]) {
|
182
|
+
return messages.length > 0 && messages[messages.length - 1]!.role === "user";
|
183
|
+
}
|
184
|
+
|
185
|
+
export function processToolCalls(output: any) {
|
186
|
+
// Check for OpenAI format tool calls first
|
187
|
+
if (output.tool_calls && Array.isArray(output.tool_calls)) {
|
188
|
+
return output.tool_calls.map((toolCall: any) => {
|
189
|
+
// Handle new format
|
190
|
+
if (toolCall.function && toolCall.id) {
|
191
|
+
return {
|
192
|
+
toolCallType: "function",
|
193
|
+
toolCallId: toolCall.id,
|
194
|
+
toolName: toolCall.function.name,
|
195
|
+
args:
|
196
|
+
typeof toolCall.function.arguments === "string"
|
197
|
+
? toolCall.function.arguments
|
198
|
+
: JSON.stringify(toolCall.function.arguments || {}),
|
199
|
+
};
|
200
|
+
}
|
201
|
+
return {
|
202
|
+
toolCallType: "function",
|
203
|
+
toolCallId: toolCall.name,
|
204
|
+
toolName: toolCall.name,
|
205
|
+
args:
|
206
|
+
typeof toolCall.arguments === "string"
|
207
|
+
? toolCall.arguments
|
208
|
+
: JSON.stringify(toolCall.arguments || {}),
|
209
|
+
};
|
210
|
+
});
|
211
|
+
}
|
212
|
+
|
213
|
+
return [];
|
214
|
+
}
|
@@ -0,0 +1,87 @@
|
|
1
|
+
import { TooManyEmbeddingValuesForCallError, type EmbeddingModelV1 } from "@ai-sdk/provider";
|
2
|
+
import type { StringLike } from "./utils";
|
3
|
+
import type { EmbeddingModels } from "./workersai-models";
|
4
|
+
|
5
|
+
/**
 * Configuration supplied by the provider factory for every embedding model
 * instance.
 */
export type WorkersAIEmbeddingConfig = {
	/** Provider name surfaced in diagnostics and error reporting. */
	provider: string;
	/** Workers AI binding used to execute model runs. */
	binding: Ai;
	/** Optional AI Gateway routing options applied to every call. */
	gateway?: GatewayOptions;
};

/**
 * Per-model settings. The named keys configure SDK-side behaviour; any other
 * key is forwarded to the Workers AI binding via the index signature.
 */
export type WorkersAIEmbeddingSettings = {
	gateway?: GatewayOptions;
	maxEmbeddingsPerCall?: number;
	supportsParallelCalls?: boolean;
} & {
	/**
	 * Arbitrary provider-specific options forwarded unmodified.
	 */
	[key: string]: StringLike;
};
|
21
|
+
|
22
|
+
export class WorkersAIEmbeddingModel implements EmbeddingModelV1<string> {
|
23
|
+
/**
|
24
|
+
* Semantic version of the {@link EmbeddingModelV1} specification implemented
|
25
|
+
* by this class. It never changes.
|
26
|
+
*/
|
27
|
+
readonly specificationVersion = "v1";
|
28
|
+
readonly modelId: EmbeddingModels;
|
29
|
+
private readonly config: WorkersAIEmbeddingConfig;
|
30
|
+
private readonly settings: WorkersAIEmbeddingSettings;
|
31
|
+
|
32
|
+
/**
|
33
|
+
* Provider name exposed for diagnostics and error reporting.
|
34
|
+
*/
|
35
|
+
get provider(): string {
|
36
|
+
return this.config.provider;
|
37
|
+
}
|
38
|
+
|
39
|
+
get maxEmbeddingsPerCall(): number {
|
40
|
+
// https://developers.cloudflare.com/workers-ai/platform/limits/#text-embeddings
|
41
|
+
const maxEmbeddingsPerCall = this.modelId === "@cf/baai/bge-large-en-v1.5" ? 1500 : 3000;
|
42
|
+
return this.settings.maxEmbeddingsPerCall ?? maxEmbeddingsPerCall;
|
43
|
+
}
|
44
|
+
|
45
|
+
get supportsParallelCalls(): boolean {
|
46
|
+
return this.settings.supportsParallelCalls ?? true;
|
47
|
+
}
|
48
|
+
|
49
|
+
constructor(
|
50
|
+
modelId: EmbeddingModels,
|
51
|
+
settings: WorkersAIEmbeddingSettings,
|
52
|
+
config: WorkersAIEmbeddingConfig,
|
53
|
+
) {
|
54
|
+
this.modelId = modelId;
|
55
|
+
this.settings = settings;
|
56
|
+
this.config = config;
|
57
|
+
}
|
58
|
+
|
59
|
+
async doEmbed({
|
60
|
+
values,
|
61
|
+
}: Parameters<EmbeddingModelV1<string>["doEmbed"]>[0]): Promise<
|
62
|
+
Awaited<ReturnType<EmbeddingModelV1<string>["doEmbed"]>>
|
63
|
+
> {
|
64
|
+
if (values.length > this.maxEmbeddingsPerCall) {
|
65
|
+
throw new TooManyEmbeddingValuesForCallError({
|
66
|
+
provider: this.provider,
|
67
|
+
modelId: this.modelId,
|
68
|
+
maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
|
69
|
+
values,
|
70
|
+
});
|
71
|
+
}
|
72
|
+
|
73
|
+
const { gateway, ...passthroughOptions } = this.settings;
|
74
|
+
|
75
|
+
const response = await this.config.binding.run(
|
76
|
+
this.modelId,
|
77
|
+
{
|
78
|
+
text: values,
|
79
|
+
},
|
80
|
+
{ gateway: this.config.gateway ?? gateway, ...passthroughOptions },
|
81
|
+
);
|
82
|
+
|
83
|
+
return {
|
84
|
+
embeddings: response.data,
|
85
|
+
};
|
86
|
+
}
|
87
|
+
}
|
@@ -8,9 +8,9 @@ import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-mess
|
|
8
8
|
import type { WorkersAIChatSettings } from "./workersai-chat-settings";
|
9
9
|
import type { TextGenerationModels } from "./workersai-models";
|
10
10
|
|
11
|
-
import { events } from "fetch-event-stream";
|
12
11
|
import { mapWorkersAIUsage } from "./map-workersai-usage";
|
13
|
-
import
|
12
|
+
import { getMappedStream } from "./streaming";
|
13
|
+
import { lastMessageWasUser, prepareToolsAndToolChoice, processToolCalls } from "./utils";
|
14
14
|
|
15
15
|
type WorkersAIChatConfig = {
|
16
16
|
provider: string;
|
@@ -138,9 +138,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
138
138
|
const { gateway, safePrompt, ...passthroughOptions } = this.settings;
|
139
139
|
|
140
140
|
// Extract image from messages if present
|
141
|
-
const { messages, images } = convertToWorkersAIChatMessages(
|
142
|
-
options.prompt,
|
143
|
-
);
|
141
|
+
const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
|
144
142
|
|
145
143
|
// TODO: support for multiple images
|
146
144
|
if (images.length !== 0 && images.length !== 1) {
|
@@ -189,9 +187,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
189
187
|
const { args, warnings } = this.getArgs(options);
|
190
188
|
|
191
189
|
// Extract image from messages if present
|
192
|
-
const { messages, images } = convertToWorkersAIChatMessages(
|
193
|
-
options.prompt,
|
194
|
-
);
|
190
|
+
const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
|
195
191
|
|
196
192
|
// [1] When the latest message is not a tool response, we use the regular generate function
|
197
193
|
// and simulate it as a streamed response in order to satisfy the AI SDK's interface for
|
@@ -265,129 +261,10 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
265
261
|
throw new Error("This shouldn't happen");
|
266
262
|
}
|
267
263
|
|
268
|
-
const chunkEvent = events(new Response(response));
|
269
|
-
let usage = { promptTokens: 0, completionTokens: 0 };
|
270
|
-
|
271
264
|
return {
|
272
|
-
stream: new
|
273
|
-
async start(controller) {
|
274
|
-
for await (const event of chunkEvent) {
|
275
|
-
if (!event.data) {
|
276
|
-
continue;
|
277
|
-
}
|
278
|
-
if (event.data === "[DONE]") {
|
279
|
-
break;
|
280
|
-
}
|
281
|
-
const chunk = JSON.parse(event.data);
|
282
|
-
if (chunk.usage) {
|
283
|
-
usage = mapWorkersAIUsage(chunk);
|
284
|
-
}
|
285
|
-
chunk.response?.length &&
|
286
|
-
controller.enqueue({
|
287
|
-
type: "text-delta",
|
288
|
-
textDelta: chunk.response,
|
289
|
-
});
|
290
|
-
}
|
291
|
-
controller.enqueue({
|
292
|
-
type: "finish",
|
293
|
-
finishReason: "stop",
|
294
|
-
usage: usage,
|
295
|
-
});
|
296
|
-
controller.close();
|
297
|
-
},
|
298
|
-
}),
|
265
|
+
stream: getMappedStream(new Response(response)),
|
299
266
|
rawCall: { rawPrompt: messages, rawSettings: args },
|
300
267
|
warnings,
|
301
268
|
};
|
302
269
|
}
|
303
270
|
}
|
304
|
-
|
305
|
-
function processToolCalls(output: any) {
|
306
|
-
// Check for OpenAI format tool calls first
|
307
|
-
if (output.tool_calls && Array.isArray(output.tool_calls)) {
|
308
|
-
return output.tool_calls.map((toolCall: any) => {
|
309
|
-
// Handle new format
|
310
|
-
if (toolCall.function && toolCall.id) {
|
311
|
-
return {
|
312
|
-
toolCallType: "function",
|
313
|
-
toolCallId: toolCall.id,
|
314
|
-
toolName: toolCall.function.name,
|
315
|
-
args:
|
316
|
-
typeof toolCall.function.arguments === "string"
|
317
|
-
? toolCall.function.arguments
|
318
|
-
: JSON.stringify(toolCall.function.arguments || {}),
|
319
|
-
};
|
320
|
-
}
|
321
|
-
return {
|
322
|
-
toolCallType: "function",
|
323
|
-
toolCallId: toolCall.name,
|
324
|
-
toolName: toolCall.name,
|
325
|
-
args:
|
326
|
-
typeof toolCall.arguments === "string"
|
327
|
-
? toolCall.arguments
|
328
|
-
: JSON.stringify(toolCall.arguments || {}),
|
329
|
-
};
|
330
|
-
});
|
331
|
-
}
|
332
|
-
|
333
|
-
return [];
|
334
|
-
}
|
335
|
-
|
336
|
-
function prepareToolsAndToolChoice(
|
337
|
-
mode: Parameters<LanguageModelV1["doGenerate"]>[0]["mode"] & {
|
338
|
-
type: "regular";
|
339
|
-
},
|
340
|
-
) {
|
341
|
-
// when the tools array is empty, change it to undefined to prevent errors:
|
342
|
-
const tools = mode.tools?.length ? mode.tools : undefined;
|
343
|
-
|
344
|
-
if (tools == null) {
|
345
|
-
return { tools: undefined, tool_choice: undefined };
|
346
|
-
}
|
347
|
-
|
348
|
-
const mappedTools = tools.map((tool) => ({
|
349
|
-
type: "function",
|
350
|
-
function: {
|
351
|
-
name: tool.name,
|
352
|
-
// @ts-expect-error - description is not a property of tool
|
353
|
-
description: tool.description,
|
354
|
-
// @ts-expect-error - parameters is not a property of tool
|
355
|
-
parameters: tool.parameters,
|
356
|
-
},
|
357
|
-
}));
|
358
|
-
|
359
|
-
const toolChoice = mode.toolChoice;
|
360
|
-
|
361
|
-
if (toolChoice == null) {
|
362
|
-
return { tools: mappedTools, tool_choice: undefined };
|
363
|
-
}
|
364
|
-
|
365
|
-
const type = toolChoice.type;
|
366
|
-
|
367
|
-
switch (type) {
|
368
|
-
case "auto":
|
369
|
-
return { tools: mappedTools, tool_choice: type };
|
370
|
-
case "none":
|
371
|
-
return { tools: mappedTools, tool_choice: type };
|
372
|
-
case "required":
|
373
|
-
return { tools: mappedTools, tool_choice: "any" };
|
374
|
-
|
375
|
-
// workersAI does not support tool mode directly,
|
376
|
-
// so we filter the tools and force the tool choice through 'any'
|
377
|
-
case "tool":
|
378
|
-
return {
|
379
|
-
tools: mappedTools.filter(
|
380
|
-
(tool) => tool.function.name === toolChoice.toolName,
|
381
|
-
),
|
382
|
-
tool_choice: "any",
|
383
|
-
};
|
384
|
-
default: {
|
385
|
-
const exhaustiveCheck = type satisfies never;
|
386
|
-
throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
|
387
|
-
}
|
388
|
-
}
|
389
|
-
}
|
390
|
-
|
391
|
-
function lastMessageWasUser(messages: WorkersAIChatPrompt) {
|
392
|
-
return messages.length > 0 && messages[messages.length - 1].role === "user";
|
393
|
-
}
|
package/src/workersai-models.ts
CHANGED
@@ -11,4 +11,9 @@ export type TextGenerationModels = Exclude<
|
|
11
11
|
*/
|
12
12
|
export type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
|
13
13
|
|
14
|
+
/**
|
15
|
+
 * The names of the BaseAiTextEmbeddings models.
|
16
|
+
*/
|
17
|
+
export type EmbeddingModels = value2key<AiModels, BaseAiTextEmbeddings>;
|
18
|
+
|
14
19
|
type value2key<T, V> = { [K in keyof T]: T[K] extends V ? K : never }[keyof T];
|