workers-ai-provider 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -2
- package/dist/index.d.ts +30 -5
- package/dist/index.js +347 -52
- package/dist/index.js.map +1 -1
- package/package.json +40 -44
- package/src/convert-to-workersai-chat-messages.ts +94 -97
- package/src/index.ts +84 -82
- package/src/map-workersai-finish-reason.ts +12 -12
- package/src/map-workersai-usage.ts +8 -4
- package/src/utils.ts +89 -52
- package/src/workersai-chat-language-model.ts +313 -325
- package/src/workersai-chat-prompt.ts +18 -18
- package/src/workersai-chat-settings.ts +11 -11
- package/src/workersai-error.ts +10 -13
- package/src/workersai-image-config.ts +5 -0
- package/src/workersai-image-model.ts +114 -0
- package/src/workersai-image-settings.ts +3 -0
- package/src/workersai-models.ts +7 -2
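For orientation before the diff itself, here is a minimal consumer-side sketch of how this provider is wired into the Vercel AI SDK. `createWorkersAI` and its `binding` option are the package's documented entry point; the `Env` shape and the model id are illustrative assumptions, not part of this diff:

```ts
import { createWorkersAI } from "workers-ai-provider";
import { generateText } from "ai";

// Assumed Worker environment with an `AI` binding configured in wrangler.toml.
interface Env {
  AI: Ai;
}

export default {
  async fetch(_request: Request, env: Env): Promise<Response> {
    // Wrap the Workers AI binding as an AI SDK provider.
    const workersai = createWorkersAI({ binding: env.AI });

    const { text } = await generateText({
      // Any Workers AI text-generation model id works here.
      model: workersai("@cf/meta/llama-3.1-8b-instruct"),
      prompt: "Write a haiku about recursion.",
    });

    return new Response(text);
  },
};
```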
The hunks below are from package/src/workersai-chat-language-model.ts. Removed 0.2.0 lines whose text did not survive this extract are collapsed into bracketed placeholder lines.

```diff
@@ -1,10 +1,9 @@
 import {
-[… 4 lines removed; 0.2.0 content not preserved in this extract …]
+  type LanguageModelV1,
+  type LanguageModelV1CallWarning,
+  type LanguageModelV1StreamPart,
+  UnsupportedFunctionalityError,
 } from "@ai-sdk/provider";
-import { z } from "zod";
 import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";
 import type { WorkersAIChatSettings } from "./workersai-chat-settings";
 import type { TextGenerationModels } from "./workersai-models";
```
```diff
@@ -14,334 +13,323 @@ import { mapWorkersAIUsage } from "./map-workersai-usage";
 import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";
 
 type WorkersAIChatConfig = {
-[… 3 lines removed; 0.2.0 content not preserved in this extract …]
+  provider: string;
+  binding: Ai;
+  gateway?: GatewayOptions;
 };
 
 export class WorkersAIChatLanguageModel implements LanguageModelV1 {
-[… 256 lines removed; the 0.2.0 class body was not preserved in this extract …]
+  readonly specificationVersion = "v1";
+  readonly defaultObjectGenerationMode = "json";
+
+  readonly modelId: TextGenerationModels;
+  readonly settings: WorkersAIChatSettings;
+
+  private readonly config: WorkersAIChatConfig;
+
+  constructor(
+    modelId: TextGenerationModels,
+    settings: WorkersAIChatSettings,
+    config: WorkersAIChatConfig,
+  ) {
+    this.modelId = modelId;
+    this.settings = settings;
+    this.config = config;
+  }
+
+  get provider(): string {
+    return this.config.provider;
+  }
+
+  private getArgs({
+    mode,
+    prompt,
+    maxTokens,
+    temperature,
+    topP,
+    frequencyPenalty,
+    presencePenalty,
+    seed,
+  }: Parameters<LanguageModelV1["doGenerate"]>[0]) {
+    const type = mode.type;
+
+    const warnings: LanguageModelV1CallWarning[] = [];
+
+    if (frequencyPenalty != null) {
+      warnings.push({
+        type: "unsupported-setting",
+        setting: "frequencyPenalty",
+      });
+    }
+
+    if (presencePenalty != null) {
+      warnings.push({
+        type: "unsupported-setting",
+        setting: "presencePenalty",
+      });
+    }
+
+    const baseArgs = {
+      // model id:
+      model: this.modelId,
+
+      // model specific settings:
+      safe_prompt: this.settings.safePrompt,
+
+      // standardized settings:
+      max_tokens: maxTokens,
+      temperature,
+      top_p: topP,
+      random_seed: seed,
+
+      // messages:
+      messages: convertToWorkersAIChatMessages(prompt),
+    };
+
+    switch (type) {
+      case "regular": {
+        return {
+          args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
+          warnings,
+        };
+      }
+
+      case "object-json": {
+        return {
+          args: {
+            ...baseArgs,
+            response_format: {
+              type: "json_schema",
+              json_schema: mode.schema,
+            },
+            tools: undefined,
+          },
+          warnings,
+        };
+      }
+
+      case "object-tool": {
+        return {
+          args: {
+            ...baseArgs,
+            tool_choice: "any",
+            tools: [{ type: "function", function: mode.tool }],
+          },
+          warnings,
+        };
+      }
+
+      // @ts-expect-error - this is unreachable code
+      // TODO: fixme
+      case "object-grammar": {
+        throw new UnsupportedFunctionalityError({
+          functionality: "object-grammar mode",
+        });
+      }
+
+      default: {
+        const exhaustiveCheck = type satisfies never;
+        throw new Error(`Unsupported type: ${exhaustiveCheck}`);
+      }
+    }
+  }
+
```
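As the `getArgs` hunk above shows, `frequencyPenalty` and `presencePenalty` are not sent to Workers AI; they are surfaced as `unsupported-setting` call warnings instead of being rejected. A hedged sketch of what a caller would observe (assuming AI SDK v4-style `generateText`, which forwards provider warnings on its result):

```ts
import { createWorkersAI } from "workers-ai-provider";
import { generateText } from "ai";

// Sketch: unsupported settings do not throw; they come back as warnings.
async function demo(ai: Ai) {
  const workersai = createWorkersAI({ binding: ai });

  const result = await generateText({
    model: workersai("@cf/meta/llama-3.1-8b-instruct"),
    prompt: "Hello!",
    frequencyPenalty: 0.5, // silently dropped by this provider
  });

  // Expected to include { type: "unsupported-setting", setting: "frequencyPenalty" }.
  console.log(result.warnings);
}
```

The hunk continues with `doGenerate`: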
```diff
+  async doGenerate(
+    options: Parameters<LanguageModelV1["doGenerate"]>[0],
+  ): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
+    const { args, warnings } = this.getArgs(options);
+
+    const output = await this.config.binding.run(
+      args.model,
+      {
+        messages: args.messages,
+        max_tokens: args.max_tokens,
+        temperature: args.temperature,
+        tools: args.tools,
+        top_p: args.top_p,
+        // @ts-expect-error response_format not yet added to types
+        response_format: args.response_format,
+      },
+      { gateway: this.config.gateway ?? this.settings.gateway },
+    );
+
+    if (output instanceof ReadableStream) {
+      throw new Error("This shouldn't happen");
+    }
+
+    return {
+      text:
+        typeof output.response === "object" && output.response !== null
+          ? JSON.stringify(output.response) // ai-sdk expects a string here
+          : output.response,
+      toolCalls: output.tool_calls?.map((toolCall) => ({
+        toolCallType: "function",
+        toolCallId: toolCall.name,
+        toolName: toolCall.name,
+        args: JSON.stringify(toolCall.arguments || {}),
+      })),
+      finishReason: "stop", // TODO: mapWorkersAIFinishReason(response.finish_reason),
+      rawCall: { rawPrompt: args.messages, rawSettings: args },
+      usage: mapWorkersAIUsage(output),
+      warnings,
+    };
+  }
+
```
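One subtlety in `doGenerate` above: when `response_format` is used, Workers AI can return the response as a parsed object, while `LanguageModelV1` requires `text` to be a string, so the provider stringifies objects. A standalone sketch of just that normalization (the function name is illustrative):

```ts
// Mirrors the text-normalization ternary in doGenerate.
function normalizeResponse(response: unknown): string | undefined {
  return typeof response === "object" && response !== null
    ? JSON.stringify(response) // the AI SDK expects a string here
    : (response as string | undefined);
}

console.log(normalizeResponse("plain text")); // "plain text"
console.log(normalizeResponse({ answer: 42 })); // '{"answer":42}'
```

The hunk continues with `doStream`: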
```diff
+  async doStream(
+    options: Parameters<LanguageModelV1["doStream"]>[0],
+  ): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
+    const { args, warnings } = this.getArgs(options);
+
+    // [1] When the latest message is not a tool response, we use the regular generate function
+    // and simulate it as a streamed response in order to satisfy the AI SDK's interface for
+    // doStream...
+    if (args.tools?.length && lastMessageWasUser(args.messages)) {
+      const response = await this.doGenerate(options);
+
+      if (response instanceof ReadableStream) {
+        throw new Error("This shouldn't happen");
+      }
+
+      return {
+        stream: new ReadableStream<LanguageModelV1StreamPart>({
+          async start(controller) {
+            if (response.text) {
+              controller.enqueue({
+                type: "text-delta",
+                textDelta: response.text,
+              });
+            }
+            if (response.toolCalls) {
+              for (const toolCall of response.toolCalls) {
+                controller.enqueue({
+                  type: "tool-call",
+                  ...toolCall,
+                });
+              }
+            }
+            controller.enqueue({
+              type: "finish",
+              finishReason: "stop",
+              usage: response.usage,
+            });
+            controller.close();
+          },
+        }),
+        rawCall: { rawPrompt: args.messages, rawSettings: args },
+        warnings,
+      };
+    }
+
+    // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model.
+    const response = await this.config.binding.run(
+      args.model,
+      {
+        messages: args.messages,
+        max_tokens: args.max_tokens,
+        stream: true,
+        temperature: args.temperature,
+        tools: args.tools,
+        top_p: args.top_p,
+        // @ts-expect-error response_format not yet added to types
+        response_format: args.response_format,
+      },
+      { gateway: this.config.gateway ?? this.settings.gateway },
+    );
+
+    if (!(response instanceof ReadableStream)) {
+      throw new Error("This shouldn't happen");
+    }
+
+    const chunkEvent = events(new Response(response));
+    let usage = { promptTokens: 0, completionTokens: 0 };
+
+    return {
+      stream: new ReadableStream<LanguageModelV1StreamPart>({
+        async start(controller) {
+          for await (const event of chunkEvent) {
+            if (!event.data) {
+              continue;
+            }
+            if (event.data === "[DONE]") {
+              break;
+            }
+            const chunk = JSON.parse(event.data);
+            if (chunk.usage) {
+              usage = mapWorkersAIUsage(chunk);
+            }
+            chunk.response.length &&
+              controller.enqueue({
+                type: "text-delta",
+                textDelta: chunk.response,
+              });
+          }
+          controller.enqueue({
+            type: "finish",
+            finishReason: "stop",
+            usage: usage,
+          });
+          controller.close();
+        },
+      }),
+      rawCall: { rawPrompt: args.messages, rawSettings: args },
+      warnings,
+    };
+  }
 }
```
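`doStream` above takes two paths: when tools are present and the last message came from the user, it runs a blocking `doGenerate` and replays the result as a one-shot stream (path [1]); otherwise it parses the binding's server-sent-event stream directly (path [2]). A hedged consumer sketch of the tool-calling path, assuming AI SDK v4-style `streamText` and `tool` helpers (the tool definition and model id are illustrative):

```ts
import { createWorkersAI } from "workers-ai-provider";
import { streamText, tool } from "ai";
import { z } from "zod";

async function demo(ai: Ai) {
  const workersai = createWorkersAI({ binding: ai });

  const result = streamText({
    model: workersai("@cf/meta/llama-3.1-8b-instruct"),
    prompt: "What is the weather in Berlin?",
    tools: {
      // Tools + a trailing user message means the provider internally calls
      // doGenerate and simulates the stream (path [1] above).
      weather: tool({
        description: "Get the weather for a city",
        parameters: z.object({ city: z.string() }),
        execute: async ({ city }) => `Sunny in ${city}`,
      }),
    },
  });

  for await (const delta of result.textStream) {
    console.log(delta);
  }
}
```

The hunk closes with the module-level helpers; note that the zod response schemas from 0.2.0 are deleted here, matching the removal of the `zod` import in the first hunk: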
```diff
-// limited version of the schema, focussed on what is needed for the implementation
-// this approach limits breakages when the API changes and increases efficiency
-const workersAIChatResponseSchema = z.object({
-  response: z.string(),
-});
-
-// limited version of the schema, focussed on what is needed for the implementation
-// this approach limits breakages when the API changes and increases efficiency
-const workersAIChatChunkSchema = z.instanceof(Uint8Array);
 
 function prepareToolsAndToolChoice(
-[… 3 lines removed; 0.2.0 content not preserved in this extract …]
+  mode: Parameters<LanguageModelV1["doGenerate"]>[0]["mode"] & {
+    type: "regular";
+  },
 ) {
-[… 46 lines removed; 0.2.0 content not preserved in this extract …]
-  }
-}
+  // when the tools array is empty, change it to undefined to prevent errors:
+  const tools = mode.tools?.length ? mode.tools : undefined;
+
+  if (tools == null) {
+    return { tools: undefined, tool_choice: undefined };
+  }
+
+  const mappedTools = tools.map((tool) => ({
+    type: "function",
+    function: {
+      name: tool.name,
+      // @ts-expect-error - description is not a property of tool
+      description: tool.description,
+      // @ts-expect-error - parameters is not a property of tool
+      parameters: tool.parameters,
+    },
+  }));
+
+  const toolChoice = mode.toolChoice;
+
+  if (toolChoice == null) {
+    return { tools: mappedTools, tool_choice: undefined };
+  }
+
+  const type = toolChoice.type;
+
+  switch (type) {
+    case "auto":
+      return { tools: mappedTools, tool_choice: type };
+    case "none":
+      return { tools: mappedTools, tool_choice: type };
+    case "required":
+      return { tools: mappedTools, tool_choice: "any" };
+
+    // workersAI does not support tool mode directly,
+    // so we filter the tools and force the tool choice through 'any'
+    case "tool":
+      return {
+        tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
+        tool_choice: "any",
+      };
+    default: {
+      const exhaustiveCheck = type satisfies never;
+      throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
+    }
+  }
 }
 
 function lastMessageWasUser(messages: WorkersAIChatPrompt) {
-[… 1 line removed; 0.2.0 content not preserved in this extract …]
+  return messages.length > 0 && messages[messages.length - 1].role === "user";
 }
```
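To make the new `prepareToolsAndToolChoice` mapping concrete, here is the shape it would produce for one tool with `toolChoice: { type: "required" }` (values are illustrative; `required` has no direct Workers AI equivalent, so it is coerced to `"any"`):

```ts
// Hypothetical output of prepareToolsAndToolChoice for a single "weather" tool.
const mapped = {
  tools: [
    {
      type: "function",
      function: {
        name: "weather",
        description: "Get the weather for a city",
        parameters: {
          type: "object",
          properties: { city: { type: "string" } },
        },
      },
    },
  ],
  tool_choice: "any", // "required" is mapped to Workers AI's "any"
};
```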