workers-ai-provider 0.1.2 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +19 -0
- package/dist/index.d.ts +27 -18
- package/dist/index.js +83 -2
- package/dist/index.js.map +1 -1
- package/package.json +1 -1
- package/src/index.ts +57 -22
- package/src/utils.ts +64 -0
- package/src/workersai-chat-language-model.ts +48 -0
package/README.md
CHANGED
@@ -58,6 +58,25 @@ export default {
|
|
58
58
|
};
|
59
59
|
```
|
60
60
|
|
61
|
+
You can also create the provider from your Cloudflare account ID and API key instead of a binding, for example to use Workers AI outside of a Worker. Here is how to do that in a Node script:
|
62
|
+
|
63
|
+
```js
|
64
|
+
const workersai = createWorkersAI({
|
65
|
+
accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
|
66
|
+
apiKey: process.env.CLOUDFLARE_API_KEY
|
67
|
+
});
|
68
|
+
|
69
|
+
const text = await streamText({
|
70
|
+
model: workersai("@cf/meta/llama-2-7b-chat-int8"),
|
71
|
+
messages: [
|
72
|
+
{
|
73
|
+
role: "user",
|
74
|
+
content: "Write an essay about hello world",
|
75
|
+
},
|
76
|
+
],
|
77
|
+
});
|
78
|
+
```
|
79
|
+
|
61
80
|
For more info, refer to the documentation of the [Vercel AI SDK](https://sdk.vercel.ai/).
|
62
81
|
|
63
82
|
### Credits
|
package/dist/index.d.ts
CHANGED
@@ -40,6 +40,32 @@ declare class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
40
40
|
doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
|
41
41
|
}
|
42
42
|
|
43
|
+
/**
 * Options accepted by `createWorkersAI`.
 *
 * Supply exactly one of:
 * - `binding`: a Workers AI binding (e.g. `env.AI` inside a Worker), or
 * - `accountId` + `apiKey`: Cloudflare API credentials, e.g. to call Workers AI
 *   from outside a Worker.
 *
 * Mixing the two is rejected at compile time via the `never`-typed fields.
 */
type WorkersAISettings = ({
    /**
     * Provide a Cloudflare AI binding.
     */
    binding: Ai;
    /**
     * Credentials must be absent when a binding is given.
     */
    accountId?: never;
    apiKey?: never;
} | {
    /**
     * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
     */
    accountId: string;
    apiKey: string;
    /**
     * The binding must be absent when credentials are used directly.
     */
    binding?: never;
}) & {
    /**
     * Optional Cloudflare AI Gateway options.
     *
     * NOTE: when `accountId`/`apiKey` are used, this version's REST client does
     * not apply gateway options to its requests.
     */
    gateway?: GatewayOptions;
};
|
43
69
|
interface WorkersAI {
|
44
70
|
(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
|
45
71
|
/**
|
@@ -47,26 +73,9 @@ interface WorkersAI {
|
|
47
73
|
**/
|
48
74
|
chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
|
49
75
|
}
|
50
|
-
interface WorkersAISettings {
|
51
|
-
/**
|
52
|
-
* Provide an `env.AI` binding to use for the AI inference.
|
53
|
-
* You can set up an AI bindings in your Workers project
|
54
|
-
* by adding the following this to `wrangler.toml`:
|
55
|
-
|
56
|
-
```toml
|
57
|
-
[ai]
|
58
|
-
binding = "AI"
|
59
|
-
```
|
60
|
-
**/
|
61
|
-
binding: Ai;
|
62
|
-
/**
|
63
|
-
* Optionally set Cloudflare AI Gateway options.
|
64
|
-
*/
|
65
|
-
gateway?: GatewayOptions;
|
66
|
-
}
|
67
76
|
/**
 * Create a Workers AI provider instance.
 *
 * `options` must contain either a Workers AI `binding` or Cloudflare
 * `accountId` + `apiKey` credentials (see `WorkersAISettings`); `gateway` is optional.
 * The returned provider can be called directly, or via `.chat(...)`, to create a chat model.
 */
declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
|
71
80
|
|
72
81
|
export { type WorkersAI, type WorkersAISettings, createWorkersAI };
|
package/dist/index.js
CHANGED
@@ -240,6 +240,40 @@ var WorkersAIChatLanguageModel = class {
|
|
240
240
|
}
|
241
241
|
async doStream(options) {
|
242
242
|
const { args, warnings } = this.getArgs(options);
|
243
|
+
if (args.tools?.length && lastMessageWasUser(args.messages)) {
|
244
|
+
const response2 = await this.doGenerate(options);
|
245
|
+
if (response2 instanceof ReadableStream) {
|
246
|
+
throw new Error("This shouldn't happen");
|
247
|
+
}
|
248
|
+
return {
|
249
|
+
stream: new ReadableStream({
|
250
|
+
async start(controller) {
|
251
|
+
if (response2.text) {
|
252
|
+
controller.enqueue({
|
253
|
+
type: "text-delta",
|
254
|
+
textDelta: response2.text
|
255
|
+
});
|
256
|
+
}
|
257
|
+
if (response2.toolCalls) {
|
258
|
+
for (const toolCall of response2.toolCalls) {
|
259
|
+
controller.enqueue({
|
260
|
+
type: "tool-call",
|
261
|
+
...toolCall
|
262
|
+
});
|
263
|
+
}
|
264
|
+
}
|
265
|
+
controller.enqueue({
|
266
|
+
type: "finish",
|
267
|
+
finishReason: "stop",
|
268
|
+
usage: response2.usage
|
269
|
+
});
|
270
|
+
controller.close();
|
271
|
+
}
|
272
|
+
}),
|
273
|
+
rawCall: { rawPrompt: args.messages, rawSettings: args },
|
274
|
+
warnings
|
275
|
+
};
|
276
|
+
}
|
243
277
|
const response = await this.config.binding.run(
|
244
278
|
args.model,
|
245
279
|
{
|
@@ -248,7 +282,9 @@ var WorkersAIChatLanguageModel = class {
|
|
248
282
|
stream: true,
|
249
283
|
temperature: args.temperature,
|
250
284
|
tools: args.tools,
|
251
|
-
top_p: args.top_p
|
285
|
+
top_p: args.top_p,
|
286
|
+
// @ts-expect-error response_format not yet added to types
|
287
|
+
response_format: args.response_format
|
252
288
|
},
|
253
289
|
{ gateway: this.config.gateway ?? this.settings.gateway }
|
254
290
|
);
|
@@ -335,12 +371,57 @@ function prepareToolsAndToolChoice(mode) {
|
|
335
371
|
}
|
336
372
|
}
|
337
373
|
}
|
374
|
+
/**
 * Report whether the final entry of a converted chat prompt was sent by the user.
 * An empty prompt yields `false`.
 */
function lastMessageWasUser(messages) {
  if (messages.length === 0) {
    return false;
  }
  const { role } = messages[messages.length - 1];
  return role === "user";
}
|
377
|
+
|
378
|
+
// src/utils.ts
|
379
|
+
/**
 * Build a `run` function that stands in for the Workers AI binding's `run`
 * method by calling the Cloudflare REST API with an account ID and API key.
 *
 * @param accountId - Cloudflare account ID, interpolated into the request URL.
 * @param apiKey - Sent as a `Bearer` token in the `Authorization` header.
 * @returns An async `(model, inputs, options) => result` function:
 *   - with `options.returnRawResponse`, the raw fetch `Response` (no status check);
 *   - with `inputs.stream === true`, the response body stream;
 *   - otherwise, the `result` field of the JSON response.
 * @throws Error when the HTTP response status is not OK (status and body are
 *   included in the message), or when a streaming response has no body.
 */
function createRun(accountId, apiKey) {
  return async (model, inputs, options) => {
    // NOTE(review): other fields of `options` (e.g. `gateway`) are accepted for
    // binding compatibility but are not applied to REST calls here.
    const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
    const response = await fetch(url, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${apiKey}`
      },
      body: JSON.stringify(inputs)
    });

    // Callers asking for the raw response handle the status themselves.
    if (options?.returnRawResponse) {
      return response;
    }

    // Fail loudly on HTTP errors. Otherwise an error payload would surface later
    // as an `undefined` result or be parsed as if it were an SSE stream.
    if (!response.ok) {
      const detail = await response.text().catch(() => "");
      throw new Error(
        `Workers AI request failed with status ${response.status}` +
          (detail ? `: ${detail}` : "")
      );
    }

    if (inputs.stream === true) {
      if (response.body) {
        return response.body;
      }
      throw new Error("No readable body available for streaming.");
    }

    const data = await response.json();
    return data.result;
  };
}
|
338
405
|
|
339
406
|
// src/index.ts
|
340
407
|
function createWorkersAI(options) {
|
408
|
+
let binding;
|
409
|
+
if (options.binding) {
|
410
|
+
binding = options.binding;
|
411
|
+
} else {
|
412
|
+
const { accountId, apiKey } = options;
|
413
|
+
binding = {
|
414
|
+
run: createRun(accountId, apiKey)
|
415
|
+
};
|
416
|
+
}
|
417
|
+
if (!binding) {
|
418
|
+
throw new Error(
|
419
|
+
"Either a binding or credentials must be provided."
|
420
|
+
);
|
421
|
+
}
|
341
422
|
const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
|
342
423
|
provider: "workersai.chat",
|
343
|
-
binding
|
424
|
+
binding,
|
344
425
|
gateway: options.gateway
|
345
426
|
});
|
346
427
|
const provider = function(modelId, settings) {
|
package/dist/index.js.map
CHANGED
@@ -1 +1 @@
|
|
1
|
-
{"version":3,"sources":["../src/workersai-chat-language-model.ts","../src/convert-to-workersai-chat-messages.ts","../src/map-workersai-usage.ts","../src/index.ts"],"sourcesContent":["import {\n type LanguageModelV1,\n type LanguageModelV1CallWarning,\n type LanguageModelV1StreamPart,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport { z } from \"zod\";\nimport { convertToWorkersAIChatMessages } from \"./convert-to-workersai-chat-messages\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nimport { events } from \"fetch-event-stream\";\nimport { mapWorkersAIUsage } from \"./map-workersai-usage\";\n\ntype WorkersAIChatConfig = {\n provider: string;\n binding: Ai;\n gateway?: GatewayOptions;\n};\n\nexport class WorkersAIChatLanguageModel implements LanguageModelV1 {\n readonly specificationVersion = \"v1\";\n readonly defaultObjectGenerationMode = \"json\";\n\n readonly modelId: TextGenerationModels;\n readonly settings: WorkersAIChatSettings;\n\n private readonly config: WorkersAIChatConfig;\n\n constructor(\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings,\n config: WorkersAIChatConfig\n ) {\n this.modelId = modelId;\n this.settings = settings;\n this.config = config;\n }\n\n get provider(): string {\n return this.config.provider;\n }\n\n private getArgs({\n mode,\n prompt,\n maxTokens,\n temperature,\n topP,\n frequencyPenalty,\n presencePenalty,\n seed,\n }: Parameters<LanguageModelV1[\"doGenerate\"]>[0]) {\n const type = mode.type;\n\n const warnings: LanguageModelV1CallWarning[] = [];\n\n if (frequencyPenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"frequencyPenalty\",\n });\n }\n\n if (presencePenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"presencePenalty\",\n });\n }\n\n const baseArgs = {\n // model id:\n model: this.modelId,\n\n // model specific 
settings:\n safe_prompt: this.settings.safePrompt,\n\n // standardized settings:\n max_tokens: maxTokens,\n temperature,\n top_p: topP,\n random_seed: seed,\n\n // messages:\n messages: convertToWorkersAIChatMessages(prompt),\n };\n\n switch (type) {\n case \"regular\": {\n return {\n args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },\n warnings,\n };\n }\n\n case \"object-json\": {\n return {\n args: {\n ...baseArgs,\n response_format: {\n type: \"json_schema\",\n json_schema: mode.schema,\n },\n tools: undefined,\n },\n warnings,\n };\n }\n\n case \"object-tool\": {\n return {\n args: {\n ...baseArgs,\n tool_choice: \"any\",\n tools: [{ type: \"function\", function: mode.tool }],\n },\n warnings,\n };\n }\n\n // @ts-expect-error - this is unreachable code\n // TODO: fixme\n case \"object-grammar\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"object-grammar mode\",\n });\n }\n\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported type: ${exhaustiveCheck}`);\n }\n }\n }\n\n async doGenerate(\n options: Parameters<LanguageModelV1[\"doGenerate\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doGenerate\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const output = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // @ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (output instanceof ReadableStream) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n text:\n typeof output.response === \"object\" && output.response !== null\n ? 
JSON.stringify(output.response) // ai-sdk expects a string here\n : output.response,\n toolCalls: output.tool_calls?.map((toolCall) => ({\n toolCallType: \"function\",\n toolCallId: toolCall.name,\n toolName: toolCall.name,\n args: JSON.stringify(toolCall.arguments || {}),\n })),\n finishReason: \"stop\", // TODO: mapWorkersAIFinishReason(response.finish_reason),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n usage: mapWorkersAIUsage(output),\n warnings,\n };\n }\n\n async doStream(\n options: Parameters<LanguageModelV1[\"doStream\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doStream\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const response = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n stream: true,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (!(response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n const chunkEvent = events(new Response(response));\n let usage = { promptTokens: 0, completionTokens: 0 };\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n for await (const event of chunkEvent) {\n if (!event.data) {\n continue;\n }\n if (event.data === \"[DONE]\") {\n break;\n }\n const chunk = JSON.parse(event.data);\n if (chunk.usage) {\n usage = mapWorkersAIUsage(chunk);\n }\n chunk.response.length &&\n controller.enqueue({\n type: \"text-delta\",\n textDelta: chunk.response,\n });\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n}\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst 
workersAIChatResponseSchema = z.object({\n response: z.string(),\n});\n\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatChunkSchema = z.instanceof(Uint8Array);\n\nfunction prepareToolsAndToolChoice(\n mode: Parameters<LanguageModelV1[\"doGenerate\"]>[0][\"mode\"] & {\n type: \"regular\";\n }\n) {\n // when the tools array is empty, change it to undefined to prevent errors:\n const tools = mode.tools?.length ? mode.tools : undefined;\n\n if (tools == null) {\n return { tools: undefined, tool_choice: undefined };\n }\n\n const mappedTools = tools.map((tool) => ({\n type: \"function\",\n function: {\n name: tool.name,\n // @ts-expect-error - description is not a property of tool\n description: tool.description,\n // @ts-expect-error - parameters is not a property of tool\n parameters: tool.parameters,\n },\n }));\n\n const toolChoice = mode.toolChoice;\n\n if (toolChoice == null) {\n return { tools: mappedTools, tool_choice: undefined };\n }\n\n const type = toolChoice.type;\n\n switch (type) {\n case \"auto\":\n return { tools: mappedTools, tool_choice: type };\n case \"none\":\n return { tools: mappedTools, tool_choice: type };\n case \"required\":\n return { tools: mappedTools, tool_choice: \"any\" };\n\n // workersAI does not support tool mode directly,\n // so we filter the tools and force the tool choice through 'any'\n case \"tool\":\n return {\n tools: mappedTools.filter(\n (tool) => tool.function.name === toolChoice.toolName\n ),\n tool_choice: \"any\",\n };\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);\n }\n }\n}\n","import {\n type LanguageModelV1Prompt,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\n// TODO\nexport function 
convertToWorkersAIChatMessages(\n prompt: LanguageModelV1Prompt\n): WorkersAIChatPrompt {\n const messages: WorkersAIChatPrompt = [];\n\n for (const { role, content } of prompt) {\n switch (role) {\n case \"system\": {\n messages.push({ role: \"system\", content });\n break;\n }\n\n case \"user\": {\n messages.push({\n role: \"user\",\n content: content\n .map((part) => {\n switch (part.type) {\n case \"text\": {\n return part.text;\n }\n case \"image\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"image-part\",\n });\n }\n }\n })\n .join(\"\"),\n });\n break;\n }\n\n case \"assistant\": {\n let text = \"\";\n const toolCalls: Array<{\n id: string;\n type: \"function\";\n function: { name: string; arguments: string };\n }> = [];\n\n for (const part of content) {\n switch (part.type) {\n case \"text\": {\n text += part.text;\n break;\n }\n case \"tool-call\": {\n text = JSON.stringify({ name: part.toolName, parameters: part.args });\n\n toolCalls.push({\n id: part.toolCallId,\n type: \"function\",\n function: {\n name: part.toolName,\n arguments: JSON.stringify(part.args),\n },\n });\n break;\n }\n default: {\n const exhaustiveCheck = part satisfies never;\n throw new Error(`Unsupported part: ${exhaustiveCheck}`);\n }\n }\n }\n\n messages.push({\n role: \"assistant\",\n content: text,\n tool_calls:\n toolCalls.length > 0\n ? 
toolCalls.map(({ function: { name, arguments: args } }) => ({\n id: \"null\",\n type: \"function\",\n function: { name, arguments: args },\n }))\n : undefined,\n });\n\n break;\n }\n case \"tool\": {\n for (const toolResponse of content) {\n messages.push({\n role: \"tool\",\n name: toolResponse.toolName,\n content: JSON.stringify(toolResponse.result),\n });\n }\n break;\n }\n default: {\n const exhaustiveCheck = role satisfies never;\n throw new Error(`Unsupported role: ${exhaustiveCheck}`);\n }\n }\n }\n\n return messages;\n}\n","export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {\n\tconst usage = (output as { usage: { prompt_tokens: number, completion_tokens: number} }).usage ?? {\n\t\tprompt_tokens: 0,\n\t\tcompletion_tokens: 0,\n\t};\n\n\treturn {\n\t\tpromptTokens: usage.prompt_tokens,\n\t\tcompletionTokens: usage.completion_tokens\n\t}\n}\n","import { WorkersAIChatLanguageModel } from \"./workersai-chat-language-model\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nexport interface WorkersAI {\n (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n\n /**\n * Creates a model for text generation.\n **/\n chat(\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n}\n\nexport interface WorkersAISettings {\n /**\n * Provide an `env.AI` binding to use for the AI inference.\n * You can set up an AI bindings in your Workers project\n * by adding the following this to `wrangler.toml`:\n\n ```toml\n[ai]\nbinding = \"AI\"\n ```\n **/\n binding: Ai;\n /**\n * Optionally set Cloudflare AI Gateway options.\n */\n gateway?: GatewayOptions;\n}\n\n/**\n * Create a Workers AI provider instance.\n **/\nexport function createWorkersAI(options: WorkersAISettings): WorkersAI {\n const createChatModel = (\n modelId: TextGenerationModels,\n 
settings: WorkersAIChatSettings = {}\n ) =>\n new WorkersAIChatLanguageModel(modelId, settings, {\n provider: \"workersai.chat\",\n binding: options.binding,\n gateway: options.gateway,\n });\n\n const provider = function (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ) {\n if (new.target) {\n throw new Error(\n \"The WorkersAI model function cannot be called with the new keyword.\"\n );\n }\n\n return createChatModel(modelId, settings);\n };\n\n provider.chat = createChatModel;\n\n return provider;\n}\n"],"mappings":";;;;;AAAA;AAAA,EAIE,iCAAAA;AAAA,OACK;AACP,SAAS,SAAS;;;ACNlB;AAAA,EAEE;AAAA,OACK;AAIA,SAAS,+BACd,QACqB;AACrB,QAAM,WAAgC,CAAC;AAEvC,aAAW,EAAE,MAAM,QAAQ,KAAK,QAAQ;AACtC,YAAQ,MAAM;AAAA,MACZ,KAAK,UAAU;AACb,iBAAS,KAAK,EAAE,MAAM,UAAU,QAAQ,CAAC;AACzC;AAAA,MACF;AAAA,MAEA,KAAK,QAAQ;AACX,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS,QACN,IAAI,CAAC,SAAS;AACb,oBAAQ,KAAK,MAAM;AAAA,cACjB,KAAK,QAAQ;AACX,uBAAO,KAAK;AAAA,cACd;AAAA,cACA,KAAK,SAAS;AACZ,sBAAM,IAAI,8BAA8B;AAAA,kBACtC,eAAe;AAAA,gBACjB,CAAC;AAAA,cACH;AAAA,YACF;AAAA,UACF,CAAC,EACA,KAAK,EAAE;AAAA,QACZ,CAAC;AACD;AAAA,MACF;AAAA,MAEA,KAAK,aAAa;AAChB,YAAI,OAAO;AACX,cAAM,YAID,CAAC;AAEN,mBAAW,QAAQ,SAAS;AAC1B,kBAAQ,KAAK,MAAM;AAAA,YACjB,KAAK,QAAQ;AACX,sBAAQ,KAAK;AACb;AAAA,YACF;AAAA,YACA,KAAK,aAAa;AAChB,qBAAO,KAAK,UAAU,EAAE,MAAM,KAAK,UAAU,YAAY,KAAK,KAAK,CAAC;AAEpE,wBAAU,KAAK;AAAA,gBACb,IAAI,KAAK;AAAA,gBACT,MAAM;AAAA,gBACN,UAAU;AAAA,kBACR,MAAM,KAAK;AAAA,kBACX,WAAW,KAAK,UAAU,KAAK,IAAI;AAAA,gBACrC;AAAA,cACF,CAAC;AACD;AAAA,YACF;AAAA,YACA,SAAS;AACP,oBAAM,kBAAkB;AACxB,oBAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,YACxD;AAAA,UACF;AAAA,QACF;AAEA,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS;AAAA,UACT,YACE,UAAU,SAAS,IACf,UAAU,IAAI,CAAC,EAAE,UAAU,EAAE,MAAM,WAAW,KAAK,EAAE,OAAO;AAAA,YAC1D,IAAI;AAAA,YACJ,MAAM;AAAA,YACN,UAAU,EAAE,MAAM,WAAW,KAAK;AAAA,UACpC,EAAE,IACF;AAAA,QACR,CAAC;AAED;AAAA,MACF;AAAA,MACA,KAAK,QAAQ;AACX,mBAAW,gBAAgB,SAAS;AAClC,mBAAS,KAAK;AAAA,YACZ,MAAM;AAAA,YACN,MAAM,aAAa;AAAA,YACnB,SAAS,KAAK,UAAU,aAAa,MAAM;AAA
A,UAC7C,CAAC;AAAA,QACH;AACA;AAAA,MACF;AAAA,MACA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAEA,SAAO;AACT;;;ADhGA,SAAS,cAAc;;;AEXhB,SAAS,kBAAkB,QAAuD;AACxF,QAAM,QAAS,OAA0E,SAAS;AAAA,IACjG,eAAe;AAAA,IACf,mBAAmB;AAAA,EACpB;AAEA,SAAO;AAAA,IACN,cAAc,MAAM;AAAA,IACpB,kBAAkB,MAAM;AAAA,EACzB;AACD;;;AFUO,IAAM,6BAAN,MAA4D;AAAA,EASjE,YACE,SACA,UACA,QACA;AAZF,wBAAS,wBAAuB;AAChC,wBAAS,+BAA8B;AAEvC,wBAAS;AACT,wBAAS;AAET,wBAAiB;AAOf,SAAK,UAAU;AACf,SAAK,WAAW;AAChB,SAAK,SAAS;AAAA,EAChB;AAAA,EAEA,IAAI,WAAmB;AACrB,WAAO,KAAK,OAAO;AAAA,EACrB;AAAA,EAEQ,QAAQ;AAAA,IACd;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,EACF,GAAiD;AAC/C,UAAM,OAAO,KAAK;AAElB,UAAM,WAAyC,CAAC;AAEhD,QAAI,oBAAoB,MAAM;AAC5B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,QAAI,mBAAmB,MAAM;AAC3B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,UAAM,WAAW;AAAA;AAAA,MAEf,OAAO,KAAK;AAAA;AAAA,MAGZ,aAAa,KAAK,SAAS;AAAA;AAAA,MAG3B,YAAY;AAAA,MACZ;AAAA,MACA,OAAO;AAAA,MACP,aAAa;AAAA;AAAA,MAGb,UAAU,+BAA+B,MAAM;AAAA,IACjD;AAEA,YAAQ,MAAM;AAAA,MACZ,KAAK,WAAW;AACd,eAAO;AAAA,UACL,MAAM,EAAE,GAAG,UAAU,GAAG,0BAA0B,IAAI,EAAE;AAAA,UACxD;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,iBAAiB;AAAA,cACf,MAAM;AAAA,cACN,aAAa,KAAK;AAAA,YACpB;AAAA,YACA,OAAO;AAAA,UACT;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,aAAa;AAAA,YACb,OAAO,CAAC,EAAE,MAAM,YAAY,UAAU,KAAK,KAAK,CAAC;AAAA,UACnD;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA;AAAA;AAAA,MAIA,KAAK,kBAAkB;AACrB,cAAM,IAAIC,+BAA8B;AAAA,UACtC,eAAe;AAAA,QACjB,CAAC;AAAA,MACH;AAAA,MAEA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,WACJ,SAC6D;AAC7D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,SAAS,MAAM,KAAK,OAAO,QAAQ;AAAA,MACvC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,K
AAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,kBAAkB,gBAAgB;AACpC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,WAAO;AAAA,MACL,MACE,OAAO,OAAO,aAAa,YAAY,OAAO,aAAa,OACvD,KAAK,UAAU,OAAO,QAAQ,IAC9B,OAAO;AAAA,MACb,WAAW,OAAO,YAAY,IAAI,CAAC,cAAc;AAAA,QAC/C,cAAc;AAAA,QACd,YAAY,SAAS;AAAA,QACrB,UAAU,SAAS;AAAA,QACnB,MAAM,KAAK,UAAU,SAAS,aAAa,CAAC,CAAC;AAAA,MAC/C,EAAE;AAAA,MACF,cAAc;AAAA;AAAA,MACd,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD,OAAO,kBAAkB,MAAM;AAAA,MAC/B;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,SACJ,SAC2D;AAC3D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,WAAW,MAAM,KAAK,OAAO,QAAQ;AAAA,MACzC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,QAAQ;AAAA,QACR,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA,MACd;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,EAAE,oBAAoB,iBAAiB;AACzC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,UAAM,aAAa,OAAO,IAAI,SAAS,QAAQ,CAAC;AAChD,QAAI,QAAQ,EAAE,cAAc,GAAG,kBAAkB,EAAE;AAEnD,WAAO;AAAA,MACL,QAAQ,IAAI,eAA0C;AAAA,QACpD,MAAM,MAAM,YAAY;AACtB,2BAAiB,SAAS,YAAY;AACpC,gBAAI,CAAC,MAAM,MAAM;AACf;AAAA,YACF;AACA,gBAAI,MAAM,SAAS,UAAU;AAC3B;AAAA,YACF;AACA,kBAAM,QAAQ,KAAK,MAAM,MAAM,IAAI;AACnC,gBAAI,MAAM,OAAO;AACf,sBAAQ,kBAAkB,KAAK;AAAA,YACjC;AACA,kBAAM,SAAS,UACb,WAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,WAAW,MAAM;AAAA,YACnB,CAAC;AAAA,UACL;AACA,qBAAW,QAAQ;AAAA,YACjB,MAAM;AAAA,YACN,cAAc;AAAA,YACd;AAAA,UACF,CAAC;AACD,qBAAW,MAAM;AAAA,QACnB;AAAA,MACF,CAAC;AAAA,MACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD;AAAA,IACF;AAAA,EACF;AACF;AAGA,IAAM,8BAA8B,EAAE,OAAO;AAAA,EAC3C,UAAU,EAAE,OAAO;AACrB,CAAC;AAID,IAAM,2BAA2B,EAAE,WAAW,UAAU;AAExD,SAAS,0BACP,MAGA;AAEA,QAAM,QAAQ,KAAK,OAAO,SAAS,KAAK,QAAQ;AAEhD,MAAI,SAAS,MAAM;AACjB,WAAO,EAAE,OAAO,QAAW,aAAa,OAAU;AAAA,EACpD;AAEA,QAAM,cAAc,MAAM,IAAI,CAAC,UAAU;AAAA,IACvC,MAAM;AAAA,IACN,UAAU;AAAA,MACR,MAAM,KAAK;AAAA;AAAA,MAEX,aAAa,KAAK;AAAA;AAAA,MAElB,YAAY,KAAK;AAAA,IACnB;AAAA,EACF,EAAE;AAEF,QAAM,aAAa,KAAK;AAExB,MAAI,cAAc,MAAM;AACtB,
WAAO,EAAE,OAAO,aAAa,aAAa,OAAU;AAAA,EACtD;AAEA,QAAM,OAAO,WAAW;AAExB,UAAQ,MAAM;AAAA,IACZ,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,MAAM;AAAA;AAAA;AAAA,IAIlD,KAAK;AACH,aAAO;AAAA,QACL,OAAO,YAAY;AAAA,UACjB,CAAC,SAAS,KAAK,SAAS,SAAS,WAAW;AAAA,QAC9C;AAAA,QACA,aAAa;AAAA,MACf;AAAA,IACF,SAAS;AACP,YAAM,kBAAkB;AACxB,YAAM,IAAI,MAAM,iCAAiC,eAAe,EAAE;AAAA,IACpE;AAAA,EACF;AACF;;;AGlQO,SAAS,gBAAgB,SAAuC;AACrE,QAAM,kBAAkB,CACtB,SACA,WAAkC,CAAC,MAEnC,IAAI,2BAA2B,SAAS,UAAU;AAAA,IAChD,UAAU;AAAA,IACV,SAAS,QAAQ;AAAA,IACjB,SAAS,QAAQ;AAAA,EACnB,CAAC;AAEH,QAAM,WAAW,SACf,SACA,UACA;AACA,QAAI,YAAY;AACd,YAAM,IAAI;AAAA,QACR;AAAA,MACF;AAAA,IACF;AAEA,WAAO,gBAAgB,SAAS,QAAQ;AAAA,EAC1C;AAEA,WAAS,OAAO;AAEhB,SAAO;AACT;","names":["UnsupportedFunctionalityError","UnsupportedFunctionalityError"]}
|
1
|
+
{"version":3,"sources":["../src/workersai-chat-language-model.ts","../src/convert-to-workersai-chat-messages.ts","../src/map-workersai-usage.ts","../src/utils.ts","../src/index.ts"],"sourcesContent":["import {\n type LanguageModelV1,\n type LanguageModelV1CallWarning,\n type LanguageModelV1StreamPart,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport { z } from \"zod\";\nimport { convertToWorkersAIChatMessages } from \"./convert-to-workersai-chat-messages\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nimport { events } from \"fetch-event-stream\";\nimport { mapWorkersAIUsage } from \"./map-workersai-usage\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\ntype WorkersAIChatConfig = {\n provider: string;\n binding: Ai;\n gateway?: GatewayOptions;\n};\n\nexport class WorkersAIChatLanguageModel implements LanguageModelV1 {\n readonly specificationVersion = \"v1\";\n readonly defaultObjectGenerationMode = \"json\";\n\n readonly modelId: TextGenerationModels;\n readonly settings: WorkersAIChatSettings;\n\n private readonly config: WorkersAIChatConfig;\n\n constructor(\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings,\n config: WorkersAIChatConfig\n ) {\n this.modelId = modelId;\n this.settings = settings;\n this.config = config;\n }\n\n get provider(): string {\n return this.config.provider;\n }\n\n private getArgs({\n mode,\n prompt,\n maxTokens,\n temperature,\n topP,\n frequencyPenalty,\n presencePenalty,\n seed,\n }: Parameters<LanguageModelV1[\"doGenerate\"]>[0]) {\n const type = mode.type;\n\n const warnings: LanguageModelV1CallWarning[] = [];\n\n if (frequencyPenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"frequencyPenalty\",\n });\n }\n\n if (presencePenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"presencePenalty\",\n });\n }\n\n 
const baseArgs = {\n // model id:\n model: this.modelId,\n\n // model specific settings:\n safe_prompt: this.settings.safePrompt,\n\n // standardized settings:\n max_tokens: maxTokens,\n temperature,\n top_p: topP,\n random_seed: seed,\n\n // messages:\n messages: convertToWorkersAIChatMessages(prompt),\n };\n\n switch (type) {\n case \"regular\": {\n return {\n args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },\n warnings,\n };\n }\n\n case \"object-json\": {\n return {\n args: {\n ...baseArgs,\n response_format: {\n type: \"json_schema\",\n json_schema: mode.schema,\n },\n tools: undefined,\n },\n warnings,\n };\n }\n\n case \"object-tool\": {\n return {\n args: {\n ...baseArgs,\n tool_choice: \"any\",\n tools: [{ type: \"function\", function: mode.tool }],\n },\n warnings,\n };\n }\n\n // @ts-expect-error - this is unreachable code\n // TODO: fixme\n case \"object-grammar\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"object-grammar mode\",\n });\n }\n\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported type: ${exhaustiveCheck}`);\n }\n }\n }\n\n async doGenerate(\n options: Parameters<LanguageModelV1[\"doGenerate\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doGenerate\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const output = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // @ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (output instanceof ReadableStream) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n text:\n typeof output.response === \"object\" && output.response !== null\n ? 
JSON.stringify(output.response) // ai-sdk expects a string here\n : output.response,\n toolCalls: output.tool_calls?.map((toolCall) => ({\n toolCallType: \"function\",\n toolCallId: toolCall.name,\n toolName: toolCall.name,\n args: JSON.stringify(toolCall.arguments || {}),\n })),\n finishReason: \"stop\", // TODO: mapWorkersAIFinishReason(response.finish_reason),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n usage: mapWorkersAIUsage(output),\n warnings,\n };\n }\n\n async doStream(\n options: Parameters<LanguageModelV1[\"doStream\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doStream\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n // [1] When the latest message is not a tool response, we use the regular generate function\n // and simulate it as a streamed response in order to satisfy the AI SDK's interface for\n // doStream...\n if (args.tools?.length && lastMessageWasUser(args.messages)) {\n const response = await this.doGenerate(options);\n\n if ((response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n if (response.text) {\n controller.enqueue({\n type: \"text-delta\",\n textDelta: response.text,\n })\n }\n if (response.toolCalls) {\n for (const toolCall of response.toolCalls) {\n controller.enqueue({\n type: \"tool-call\",\n ...toolCall,\n })\n }\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: response.usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n\n // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model.\n const response = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n stream: true,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // 
@ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (!(response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n const chunkEvent = events(new Response(response));\n let usage = { promptTokens: 0, completionTokens: 0 };\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n for await (const event of chunkEvent) {\n if (!event.data) {\n continue;\n }\n if (event.data === \"[DONE]\") {\n break;\n }\n const chunk = JSON.parse(event.data);\n if (chunk.usage) {\n usage = mapWorkersAIUsage(chunk);\n }\n chunk.response.length &&\n controller.enqueue({\n type: \"text-delta\",\n textDelta: chunk.response,\n });\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n}\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatResponseSchema = z.object({\n response: z.string(),\n});\n\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatChunkSchema = z.instanceof(Uint8Array);\n\nfunction prepareToolsAndToolChoice(\n mode: Parameters<LanguageModelV1[\"doGenerate\"]>[0][\"mode\"] & {\n type: \"regular\";\n }\n) {\n // when the tools array is empty, change it to undefined to prevent errors:\n const tools = mode.tools?.length ? 
mode.tools : undefined;\n\n if (tools == null) {\n return { tools: undefined, tool_choice: undefined };\n }\n\n const mappedTools = tools.map((tool) => ({\n type: \"function\",\n function: {\n name: tool.name,\n // @ts-expect-error - description is not a property of tool\n description: tool.description,\n // @ts-expect-error - parameters is not a property of tool\n parameters: tool.parameters,\n },\n }));\n\n const toolChoice = mode.toolChoice;\n\n if (toolChoice == null) {\n return { tools: mappedTools, tool_choice: undefined };\n }\n\n const type = toolChoice.type;\n\n switch (type) {\n case \"auto\":\n return { tools: mappedTools, tool_choice: type };\n case \"none\":\n return { tools: mappedTools, tool_choice: type };\n case \"required\":\n return { tools: mappedTools, tool_choice: \"any\" };\n\n // workersAI does not support tool mode directly,\n // so we filter the tools and force the tool choice through 'any'\n case \"tool\":\n return {\n tools: mappedTools.filter(\n (tool) => tool.function.name === toolChoice.toolName\n ),\n tool_choice: \"any\",\n };\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);\n }\n }\n}\n\nfunction lastMessageWasUser(messages: WorkersAIChatPrompt) {\n return messages.length > 0 && messages[messages.length - 1].role === \"user\";\n}\n","import {\n type LanguageModelV1Prompt,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\n// TODO\nexport function convertToWorkersAIChatMessages(\n prompt: LanguageModelV1Prompt\n): WorkersAIChatPrompt {\n const messages: WorkersAIChatPrompt = [];\n\n for (const { role, content } of prompt) {\n switch (role) {\n case \"system\": {\n messages.push({ role: \"system\", content });\n break;\n }\n\n case \"user\": {\n messages.push({\n role: \"user\",\n content: content\n .map((part) => {\n switch (part.type) {\n case \"text\": {\n return 
part.text;\n }\n case \"image\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"image-part\",\n });\n }\n }\n })\n .join(\"\"),\n });\n break;\n }\n\n case \"assistant\": {\n let text = \"\";\n const toolCalls: Array<{\n id: string;\n type: \"function\";\n function: { name: string; arguments: string };\n }> = [];\n\n for (const part of content) {\n switch (part.type) {\n case \"text\": {\n text += part.text;\n break;\n }\n case \"tool-call\": {\n text = JSON.stringify({ name: part.toolName, parameters: part.args });\n\n toolCalls.push({\n id: part.toolCallId,\n type: \"function\",\n function: {\n name: part.toolName,\n arguments: JSON.stringify(part.args),\n },\n });\n break;\n }\n default: {\n const exhaustiveCheck = part satisfies never;\n throw new Error(`Unsupported part: ${exhaustiveCheck}`);\n }\n }\n }\n\n messages.push({\n role: \"assistant\",\n content: text,\n tool_calls:\n toolCalls.length > 0\n ? toolCalls.map(({ function: { name, arguments: args } }) => ({\n id: \"null\",\n type: \"function\",\n function: { name, arguments: args },\n }))\n : undefined,\n });\n\n break;\n }\n case \"tool\": {\n for (const toolResponse of content) {\n messages.push({\n role: \"tool\",\n name: toolResponse.toolName,\n content: JSON.stringify(toolResponse.result),\n });\n }\n break;\n }\n default: {\n const exhaustiveCheck = role satisfies never;\n throw new Error(`Unsupported role: ${exhaustiveCheck}`);\n }\n }\n }\n\n return messages;\n}\n","export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {\n\tconst usage = (output as { usage: { prompt_tokens: number, completion_tokens: number} }).usage ?? 
{\n\t\tprompt_tokens: 0,\n\t\tcompletion_tokens: 0,\n\t};\n\n\treturn {\n\t\tpromptTokens: usage.prompt_tokens,\n\t\tcompletionTokens: usage.completion_tokens\n\t}\n}\n","/**\n * Creates a run method that mimics the Cloudflare Workers AI binding,\n * but uses the Cloudflare REST API under the hood.\n *\n * @param accountId - Your Cloudflare account identifier.\n * @param apiKey - Your Cloudflare API token/key with appropriate permissions.\n * @returns An function matching `Ai['run']`.\n */\nexport function createRun(accountId: string, apiKey: string): AiRun {\n return async <Name extends keyof AiModels>(model: Name, inputs: AiModels[Name][\"inputs\"], options?: AiOptions | undefined) => {\n const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;\n const body = JSON.stringify(inputs);\n\n const headers = {\n \"Content-Type\": \"application/json\",\n Authorization: `Bearer ${apiKey}`,\n };\n\n const response = await fetch(url, {\n method: \"POST\",\n headers,\n body,\n }) as Response;\n\n if (options?.returnRawResponse) {\n return response;\n }\n\n if ((inputs as AiTextGenerationInput).stream === true) {\n // If there's a stream, return the raw body so the caller can process it\n if (response.body) {\n return response.body;\n }\n throw new Error(\"No readable body available for streaming.\");\n }\n\n // Otherwise, parse JSON and return the data.result\n const data = await response.json<{ result: AiModels[Name][\"postProcessedOutputs\"] }>();\n return data.result;\n };\n}\n\ninterface AiRun {\n // (1) Return raw response if `options.returnRawResponse` is `true`.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"],\n options: AiOptions & { returnRawResponse: true }\n ): Promise<Response>;\n\n // (2) Return a stream if the input has `stream: true`.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"] & { stream: true },\n options?: AiOptions\n ): 
Promise<ReadableStream<Uint8Array>>;\n\n // (3) Return the post-processed outputs by default.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"],\n options?: AiOptions\n ): Promise<AiModels[Name][\"postProcessedOutputs\"]>;\n}\n","import { WorkersAIChatLanguageModel } from \"./workersai-chat-language-model\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\nimport { createRun } from \"./utils\";\n\nexport type WorkersAISettings =\n ({\n /**\n * Provide a Cloudflare AI binding.\n */\n binding: Ai;\n\n /**\n * Credentials must be absent when a binding is given.\n */\n accountId?: never;\n apiKey?: never;\n}\n | {\n /**\n * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.\n */\n accountId: string;\n apiKey: string;\n /**\n * Both binding must be absent if credentials are used directly.\n */\n binding?: never;\n}) & {\n /**\n * Optionally specify a gateway.\n */\n gateway?: GatewayOptions;\n};\n\nexport interface WorkersAI {\n (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n\n /**\n * Creates a model for text generation.\n **/\n chat(\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n}\n\n/**\n * Create a Workers AI provider instance.\n */\nexport function createWorkersAI(options: WorkersAISettings): WorkersAI {\n // Use a binding if one is directly provided. 
Otherwise use credentials to create\n // a `run` method that calls the Cloudflare REST API.\n let binding: Ai | undefined;\n\n if (options.binding) {\n binding = options.binding;\n } else {\n const { accountId, apiKey } = options;\n binding = {\n run: createRun(accountId, apiKey),\n } as Ai;\n }\n\n if (!binding) {\n throw new Error(\n \"Either a binding or credentials must be provided.\"\n );\n }\n\n /**\n * Helper function to create a chat model instance.\n */\n const createChatModel = (\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings = {}\n ) =>\n new WorkersAIChatLanguageModel(modelId, settings, {\n provider: \"workersai.chat\",\n binding,\n gateway: options.gateway\n });\n\n const provider = function (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ) {\n if (new.target) {\n throw new Error(\n \"The WorkersAI model function cannot be called with the new keyword.\"\n );\n }\n return createChatModel(modelId, settings);\n };\n\n provider.chat = createChatModel;\n\n return 
provider;\n}\n\n"],"mappings":";;;;;AAAA;AAAA,EAIE,iCAAAA;AAAA,OACK;AACP,SAAS,SAAS;;;ACNlB;AAAA,EAEE;AAAA,OACK;AAIA,SAAS,+BACd,QACqB;AACrB,QAAM,WAAgC,CAAC;AAEvC,aAAW,EAAE,MAAM,QAAQ,KAAK,QAAQ;AACtC,YAAQ,MAAM;AAAA,MACZ,KAAK,UAAU;AACb,iBAAS,KAAK,EAAE,MAAM,UAAU,QAAQ,CAAC;AACzC;AAAA,MACF;AAAA,MAEA,KAAK,QAAQ;AACX,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS,QACN,IAAI,CAAC,SAAS;AACb,oBAAQ,KAAK,MAAM;AAAA,cACjB,KAAK,QAAQ;AACX,uBAAO,KAAK;AAAA,cACd;AAAA,cACA,KAAK,SAAS;AACZ,sBAAM,IAAI,8BAA8B;AAAA,kBACtC,eAAe;AAAA,gBACjB,CAAC;AAAA,cACH;AAAA,YACF;AAAA,UACF,CAAC,EACA,KAAK,EAAE;AAAA,QACZ,CAAC;AACD;AAAA,MACF;AAAA,MAEA,KAAK,aAAa;AAChB,YAAI,OAAO;AACX,cAAM,YAID,CAAC;AAEN,mBAAW,QAAQ,SAAS;AAC1B,kBAAQ,KAAK,MAAM;AAAA,YACjB,KAAK,QAAQ;AACX,sBAAQ,KAAK;AACb;AAAA,YACF;AAAA,YACA,KAAK,aAAa;AAChB,qBAAO,KAAK,UAAU,EAAE,MAAM,KAAK,UAAU,YAAY,KAAK,KAAK,CAAC;AAEpE,wBAAU,KAAK;AAAA,gBACb,IAAI,KAAK;AAAA,gBACT,MAAM;AAAA,gBACN,UAAU;AAAA,kBACR,MAAM,KAAK;AAAA,kBACX,WAAW,KAAK,UAAU,KAAK,IAAI;AAAA,gBACrC;AAAA,cACF,CAAC;AACD;AAAA,YACF;AAAA,YACA,SAAS;AACP,oBAAM,kBAAkB;AACxB,oBAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,YACxD;AAAA,UACF;AAAA,QACF;AAEA,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS;AAAA,UACT,YACE,UAAU,SAAS,IACf,UAAU,IAAI,CAAC,EAAE,UAAU,EAAE,MAAM,WAAW,KAAK,EAAE,OAAO;AAAA,YAC1D,IAAI;AAAA,YACJ,MAAM;AAAA,YACN,UAAU,EAAE,MAAM,WAAW,KAAK;AAAA,UACpC,EAAE,IACF;AAAA,QACR,CAAC;AAED;AAAA,MACF;AAAA,MACA,KAAK,QAAQ;AACX,mBAAW,gBAAgB,SAAS;AAClC,mBAAS,KAAK;AAAA,YACZ,MAAM;AAAA,YACN,MAAM,aAAa;AAAA,YACnB,SAAS,KAAK,UAAU,aAAa,MAAM;AAAA,UAC7C,CAAC;AAAA,QACH;AACA;AAAA,MACF;AAAA,MACA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAEA,SAAO;AACT;;;ADhGA,SAAS,cAAc;;;AEXhB,SAAS,kBAAkB,QAAuD;AACxF,QAAM,QAAS,OAA0E,SAAS;AAAA,IACjG,eAAe;AAAA,IACf,mBAAmB;AAAA,EACpB;AAEA,SAAO;AAAA,IACN,cAAc,MAAM;AAAA,IACpB,kBAAkB,MAAM;AAAA,EACzB;AACD;;;AFWO,IAAM,6BAAN,MAA4D;AAAA,EASjE,YACE,SACA,UACA,QACA;AAZF,wBAAS,wBAAuB;AAChC,wBAAS,+BAA8B;AAEvC,wBAAS;AACT,wBAAS;AAET,wBAAiB;AAOf,SAAK,UAAU;AACf,SAAK,WAAW;AAChB,SAAK,SAAS;AAAA,EA
ChB;AAAA,EAEA,IAAI,WAAmB;AACrB,WAAO,KAAK,OAAO;AAAA,EACrB;AAAA,EAEQ,QAAQ;AAAA,IACd;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,EACF,GAAiD;AAC/C,UAAM,OAAO,KAAK;AAElB,UAAM,WAAyC,CAAC;AAEhD,QAAI,oBAAoB,MAAM;AAC5B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,QAAI,mBAAmB,MAAM;AAC3B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,UAAM,WAAW;AAAA;AAAA,MAEf,OAAO,KAAK;AAAA;AAAA,MAGZ,aAAa,KAAK,SAAS;AAAA;AAAA,MAG3B,YAAY;AAAA,MACZ;AAAA,MACA,OAAO;AAAA,MACP,aAAa;AAAA;AAAA,MAGb,UAAU,+BAA+B,MAAM;AAAA,IACjD;AAEA,YAAQ,MAAM;AAAA,MACZ,KAAK,WAAW;AACd,eAAO;AAAA,UACL,MAAM,EAAE,GAAG,UAAU,GAAG,0BAA0B,IAAI,EAAE;AAAA,UACxD;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,iBAAiB;AAAA,cACf,MAAM;AAAA,cACN,aAAa,KAAK;AAAA,YACpB;AAAA,YACA,OAAO;AAAA,UACT;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,aAAa;AAAA,YACb,OAAO,CAAC,EAAE,MAAM,YAAY,UAAU,KAAK,KAAK,CAAC;AAAA,UACnD;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA;AAAA;AAAA,MAIA,KAAK,kBAAkB;AACrB,cAAM,IAAIC,+BAA8B;AAAA,UACtC,eAAe;AAAA,QACjB,CAAC;AAAA,MACH;AAAA,MAEA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,WACJ,SAC6D;AAC7D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,SAAS,MAAM,KAAK,OAAO,QAAQ;AAAA,MACvC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,kBAAkB,gBAAgB;AACpC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,WAAO;AAAA,MACL,MACE,OAAO,OAAO,aAAa,YAAY,OAAO,aAAa,OACvD,KAAK,UAAU,OAAO,QAAQ,IAC9B,OAAO;AAAA,MACb,WAAW,OAAO,YAAY,IAAI,CAAC,cAAc;AAAA,QAC/C,cAAc;AAAA,QACd,YAAY,SAAS;AAAA,QACrB,UAAU,SAAS;AAAA,QACnB,MAAM,KAAK,UAAU,SAAS,aAAa,CAAC,CAAC;AAAA,MAC/C,EAAE;AAAA,MACF,cAAc;AAAA;AAAA,MACd,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD,OAAO,kBAAkB,MAAM;AAAA,MAC
/B;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,SACJ,SAC2D;AAC3D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAK/C,QAAI,KAAK,OAAO,UAAU,mBAAmB,KAAK,QAAQ,GAAG;AAC3D,YAAMC,YAAW,MAAM,KAAK,WAAW,OAAO;AAE9C,UAAKA,qBAAoB,gBAAiB;AACxC,cAAM,IAAI,MAAM,uBAAuB;AAAA,MACzC;AAEA,aAAO;AAAA,QACL,QAAQ,IAAI,eAA0C;AAAA,UACpD,MAAM,MAAM,YAAY;AACtB,gBAAIA,UAAS,MAAM;AACjB,yBAAW,QAAQ;AAAA,gBACjB,MAAM;AAAA,gBACN,WAAWA,UAAS;AAAA,cACtB,CAAC;AAAA,YACH;AACA,gBAAIA,UAAS,WAAW;AACtB,yBAAW,YAAYA,UAAS,WAAW;AACzC,2BAAW,QAAQ;AAAA,kBACjB,MAAM;AAAA,kBACN,GAAG;AAAA,gBACL,CAAC;AAAA,cACH;AAAA,YACF;AACA,uBAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,cAAc;AAAA,cACd,OAAOA,UAAS;AAAA,YAClB,CAAC;AACD,uBAAW,MAAM;AAAA,UACnB;AAAA,QACF,CAAC;AAAA,QACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,QACvD;AAAA,MACF;AAAA,IACF;AAGA,UAAM,WAAW,MAAM,KAAK,OAAO,QAAQ;AAAA,MACzC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,QAAQ;AAAA,QACR,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,EAAE,oBAAoB,iBAAiB;AACzC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,UAAM,aAAa,OAAO,IAAI,SAAS,QAAQ,CAAC;AAChD,QAAI,QAAQ,EAAE,cAAc,GAAG,kBAAkB,EAAE;AAEnD,WAAO;AAAA,MACL,QAAQ,IAAI,eAA0C;AAAA,QACpD,MAAM,MAAM,YAAY;AACtB,2BAAiB,SAAS,YAAY;AACpC,gBAAI,CAAC,MAAM,MAAM;AACf;AAAA,YACF;AACA,gBAAI,MAAM,SAAS,UAAU;AAC3B;AAAA,YACF;AACA,kBAAM,QAAQ,KAAK,MAAM,MAAM,IAAI;AACnC,gBAAI,MAAM,OAAO;AACf,sBAAQ,kBAAkB,KAAK;AAAA,YACjC;AACA,kBAAM,SAAS,UACb,WAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,WAAW,MAAM;AAAA,YACnB,CAAC;AAAA,UACL;AACA,qBAAW,QAAQ;AAAA,YACjB,MAAM;AAAA,YACN,cAAc;AAAA,YACd;AAAA,UACF,CAAC;AACD,qBAAW,MAAM;AAAA,QACnB;AAAA,MACF,CAAC;AAAA,MACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD;AAAA,IACF;AAAA,EACF;AACF;AAGA,IAAM,8BAA8B,EAAE,OAAO;AAAA,EAC3C,UAAU,EAAE,OAAO;AACrB,CAAC;AAID,IAAM,2BAA2B,EAAE,WAAW,UAAU;AAExD,SAAS,0BACP,MAGA;AAEA,QAAM,QAAQ,KAAK,OAAO,SAAS,KAAK,QAAQ;AAEhD,MAAI,SAAS,MAAM;AACjB,WAAO,EAAE,OAAO,QAAW,aAAa,OAAU;AAAA,EACpD;AAEA,QAAM,cAAc,MAAM,IAAI,CAAC,UAAU;AAAA,IACvC,
MAAM;AAAA,IACN,UAAU;AAAA,MACR,MAAM,KAAK;AAAA;AAAA,MAEX,aAAa,KAAK;AAAA;AAAA,MAElB,YAAY,KAAK;AAAA,IACnB;AAAA,EACF,EAAE;AAEF,QAAM,aAAa,KAAK;AAExB,MAAI,cAAc,MAAM;AACtB,WAAO,EAAE,OAAO,aAAa,aAAa,OAAU;AAAA,EACtD;AAEA,QAAM,OAAO,WAAW;AAExB,UAAQ,MAAM;AAAA,IACZ,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,MAAM;AAAA;AAAA;AAAA,IAIlD,KAAK;AACH,aAAO;AAAA,QACL,OAAO,YAAY;AAAA,UACjB,CAAC,SAAS,KAAK,SAAS,SAAS,WAAW;AAAA,QAC9C;AAAA,QACA,aAAa;AAAA,MACf;AAAA,IACF,SAAS;AACP,YAAM,kBAAkB;AACxB,YAAM,IAAI,MAAM,iCAAiC,eAAe,EAAE;AAAA,IACpE;AAAA,EACF;AACF;AAEA,SAAS,mBAAmB,UAA+B;AACzD,SAAO,SAAS,SAAS,KAAK,SAAS,SAAS,SAAS,CAAC,EAAE,SAAS;AACvE;;;AGlVO,SAAS,UAAU,WAAmB,QAAuB;AAClE,SAAO,OAAoC,OAAa,QAAkC,YAAoC;AAC5H,UAAM,MAAM,iDAAiD,SAAS,WAAW,KAAK;AACtF,UAAM,OAAO,KAAK,UAAU,MAAM;AAElC,UAAM,UAAU;AAAA,MACd,gBAAgB;AAAA,MAChB,eAAe,UAAU,MAAM;AAAA,IACjC;AAEA,UAAM,WAAW,MAAM,MAAM,KAAK;AAAA,MAChC,QAAQ;AAAA,MACR;AAAA,MACA;AAAA,IACF,CAAC;AAED,QAAI,SAAS,mBAAmB;AAC9B,aAAO;AAAA,IACT;AAEA,QAAK,OAAiC,WAAW,MAAM;AAErD,UAAI,SAAS,MAAM;AACjB,eAAO,SAAS;AAAA,MAClB;AACA,YAAM,IAAI,MAAM,2CAA2C;AAAA,IAC7D;AAGA,UAAM,OAAO,MAAM,SAAS,KAAyD;AACrF,WAAO,KAAK;AAAA,EACd;AACF;;;ACaO,SAAS,gBAAgB,SAAuC;AAGrE,MAAI;AAEJ,MAAI,QAAQ,SAAS;AACnB,cAAU,QAAQ;AAAA,EACpB,OAAO;AACL,UAAM,EAAE,WAAW,OAAO,IAAI;AAC9B,cAAU;AAAA,MACR,KAAK,UAAU,WAAW,MAAM;AAAA,IAClC;AAAA,EACF;AAEA,MAAI,CAAC,SAAS;AACZ,UAAM,IAAI;AAAA,MACR;AAAA,IACF;AAAA,EACF;AAKA,QAAM,kBAAkB,CACtB,SACA,WAAkC,CAAC,MAEnC,IAAI,2BAA2B,SAAS,UAAU;AAAA,IAChD,UAAU;AAAA,IACV;AAAA,IACA,SAAS,QAAQ;AAAA,EACnB,CAAC;AAEH,QAAM,WAAW,SACf,SACA,UACA;AACA,QAAI,YAAY;AACd,YAAM,IAAI;AAAA,QACR;AAAA,MACF;AAAA,IACF;AACA,WAAO,gBAAgB,SAAS,QAAQ;AAAA,EAC1C;AAEA,WAAS,OAAO;AAEhB,SAAO;AACT;","names":["UnsupportedFunctionalityError","UnsupportedFunctionalityError","response"]}
|
package/package.json
CHANGED
package/src/index.ts
CHANGED
@@ -1,6 +1,37 @@
|
|
1
1
|
import { WorkersAIChatLanguageModel } from "./workersai-chat-language-model";
|
2
2
|
import type { WorkersAIChatSettings } from "./workersai-chat-settings";
|
3
3
|
import type { TextGenerationModels } from "./workersai-models";
|
4
|
+
import { createRun } from "./utils";
|
5
|
+
|
6
|
+
/**
 * Options accepted by `createWorkersAI`.
 *
 * Exactly one way of reaching Workers AI must be supplied:
 *  - `binding`: an AI binding such as `env.AI` (configured via `[ai] binding = "AI"`
 *    in `wrangler.toml`), for code running inside a Worker; or
 *  - `accountId` + `apiKey`: Cloudflare API credentials, used to call the REST API
 *    from outside the Workers runtime (e.g. a Node script).
 *
 * The `never` members make the two modes mutually exclusive at the type level.
 */
export type WorkersAISettings = {
  /**
   * Optional Cloudflare AI Gateway options, passed along to every `run` call.
   * NOTE(review): the credentials-based REST path (`createRun`) currently only
   * reads `returnRawResponse` from its options, so this appears to be ignored there.
   */
  gateway?: GatewayOptions;
} & (
  | {
      /** Workers AI binding to run inference with. */
      binding: Ai;
      /** Must be omitted when a binding is given. */
      accountId?: never;
      /** Must be omitted when a binding is given. */
      apiKey?: never;
    }
  | {
      /** Cloudflare account identifier, required when no binding is given. */
      accountId: string;
      /** Cloudflare API token, required when no binding is given. */
      apiKey: string;
      /** Must be omitted when credentials are used directly. */
      binding?: never;
    }
);
|
4
35
|
|
5
36
|
export interface WorkersAI {
|
6
37
|
(
|
@@ -17,36 +48,40 @@ export interface WorkersAI {
|
|
17
48
|
): WorkersAIChatLanguageModel;
|
18
49
|
}
|
19
50
|
|
20
|
-
export interface WorkersAISettings {
|
21
|
-
/**
|
22
|
-
* Provide an `env.AI` binding to use for the AI inference.
|
23
|
-
* You can set up an AI bindings in your Workers project
|
24
|
-
* by adding the following this to `wrangler.toml`:
|
25
|
-
|
26
|
-
```toml
|
27
|
-
[ai]
|
28
|
-
binding = "AI"
|
29
|
-
```
|
30
|
-
**/
|
31
|
-
binding: Ai;
|
32
|
-
/**
|
33
|
-
* Optionally set Cloudflare AI Gateway options.
|
34
|
-
*/
|
35
|
-
gateway?: GatewayOptions;
|
36
|
-
}
|
37
|
-
|
38
51
|
/**
|
39
52
|
* Create a Workers AI provider instance.
|
40
|
-
|
53
|
+
*/
|
41
54
|
export function createWorkersAI(options: WorkersAISettings): WorkersAI {
|
55
|
+
// Use a binding if one is directly provided. Otherwise use credentials to create
|
56
|
+
// a `run` method that calls the Cloudflare REST API.
|
57
|
+
let binding: Ai | undefined;
|
58
|
+
|
59
|
+
if (options.binding) {
|
60
|
+
binding = options.binding;
|
61
|
+
} else {
|
62
|
+
const { accountId, apiKey } = options;
|
63
|
+
binding = {
|
64
|
+
run: createRun(accountId, apiKey),
|
65
|
+
} as Ai;
|
66
|
+
}
|
67
|
+
|
68
|
+
if (!binding) {
|
69
|
+
throw new Error(
|
70
|
+
"Either a binding or credentials must be provided."
|
71
|
+
);
|
72
|
+
}
|
73
|
+
|
74
|
+
/**
|
75
|
+
* Helper function to create a chat model instance.
|
76
|
+
*/
|
42
77
|
const createChatModel = (
|
43
78
|
modelId: TextGenerationModels,
|
44
79
|
settings: WorkersAIChatSettings = {}
|
45
80
|
) =>
|
46
81
|
new WorkersAIChatLanguageModel(modelId, settings, {
|
47
82
|
provider: "workersai.chat",
|
48
|
-
binding
|
49
|
-
gateway: options.gateway
|
83
|
+
binding,
|
84
|
+
gateway: options.gateway
|
50
85
|
});
|
51
86
|
|
52
87
|
const provider = function (
|
@@ -58,7 +93,6 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
|
|
58
93
|
"The WorkersAI model function cannot be called with the new keyword."
|
59
94
|
);
|
60
95
|
}
|
61
|
-
|
62
96
|
return createChatModel(modelId, settings);
|
63
97
|
};
|
64
98
|
|
@@ -66,3 +100,4 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
|
|
66
100
|
|
67
101
|
return provider;
|
68
102
|
}
|
103
|
+
|
package/src/utils.ts
ADDED
@@ -0,0 +1,64 @@
|
|
1
|
+
/**
 * Creates a run method that mimics the Cloudflare Workers AI binding,
 * but uses the Cloudflare REST API under the hood.
 *
 * Every call POSTs the JSON-encoded `inputs` to
 * `https://api.cloudflare.com/client/v4/accounts/{accountId}/ai/run/{model}`
 * with a Bearer token, then shapes the reply the way the binding would.
 *
 * @param accountId - Your Cloudflare account identifier.
 * @param apiKey - Your Cloudflare API token/key with appropriate permissions.
 * @returns A function matching `Ai['run']`. It resolves to:
 *   - the raw `Response` when `options.returnRawResponse` is set (the HTTP
 *     status is left for the caller to inspect);
 *   - the response body stream when `inputs.stream === true`;
 *   - otherwise the `result` field of the JSON reply.
 *   It rejects when the API answers with a non-2xx status (the response body
 *   is included in the error message), or when a streaming reply has no body.
 *   Only `options.returnRawResponse` is honoured; other `AiOptions` (such as
 *   `gateway`) are not used by this REST-backed implementation.
 */
export function createRun(accountId: string, apiKey: string): AiRun {
  return async <Name extends keyof AiModels>(
    model: Name,
    inputs: AiModels[Name]["inputs"],
    options?: AiOptions | undefined
  ) => {
    // Model ids like "@cf/meta/llama-2-7b-chat-int8" contain slashes that the
    // REST route expects verbatim, so the id is interpolated as-is.
    const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
    const body = JSON.stringify(inputs);

    const headers = {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    };

    const response = (await fetch(url, {
      method: "POST",
      headers,
      body,
    })) as Response;

    // Callers that ask for the raw response handle status codes themselves.
    if (options?.returnRawResponse) {
      return response;
    }

    // Without this check an error reply (e.g. 401 for a bad token) would be
    // returned as `undefined` below, or parsed as an SSE stream when streaming.
    if (!response.ok) {
      const detail = await response.text().catch(() => "");
      throw new Error(
        `Workers AI request failed with status ${response.status}${
          detail ? `: ${detail}` : ""
        }`
      );
    }

    if ((inputs as AiTextGenerationInput).stream === true) {
      // If there's a stream, return the raw body so the caller can process it
      if (response.body) {
        return response.body;
      }
      throw new Error("No readable body available for streaming.");
    }

    // Otherwise, parse JSON and return the data.result
    const data = await response.json<{ result: AiModels[Name]["postProcessedOutputs"] }>();
    return data.result;
  };
}
|
42
|
+
|
43
|
+
/**
 * Call signatures of the function returned by `createRun`, mirroring `Ai['run']`.
 *
 * Overloads are resolved top to bottom, so the order below is significant:
 * the raw-response form is checked first, then the streaming form, and the
 * post-processed-output form acts as the fallback.
 */
interface AiRun {
  /** `options.returnRawResponse: true` resolves to the untouched fetch `Response`. */
  <ModelName extends keyof AiModels>(
    model: ModelName,
    inputs: AiModels[ModelName]["inputs"],
    options: AiOptions & { returnRawResponse: true },
  ): Promise<Response>;

  /** `inputs.stream: true` resolves to the response body as a byte stream. */
  <ModelName extends keyof AiModels>(
    model: ModelName,
    inputs: AiModels[ModelName]["inputs"] & { stream: true },
    options?: AiOptions,
  ): Promise<ReadableStream<Uint8Array>>;

  /** Any other call resolves to the model's post-processed output. */
  <ModelName extends keyof AiModels>(
    model: ModelName,
    inputs: AiModels[ModelName]["inputs"],
    options?: AiOptions,
  ): Promise<AiModels[ModelName]["postProcessedOutputs"]>;
}
|
@@ -11,6 +11,7 @@ import type { TextGenerationModels } from "./workersai-models";
|
|
11
11
|
|
12
12
|
import { events } from "fetch-event-stream";
|
13
13
|
import { mapWorkersAIUsage } from "./map-workersai-usage";
|
14
|
+
import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";
|
14
15
|
|
15
16
|
type WorkersAIChatConfig = {
|
16
17
|
provider: string;
|
@@ -180,6 +181,47 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
180
181
|
): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
|
181
182
|
const { args, warnings } = this.getArgs(options);
|
182
183
|
|
184
|
+
// [1] When the latest message is not a tool response, we use the regular generate function
|
185
|
+
// and simulate it as a streamed response in order to satisfy the AI SDK's interface for
|
186
|
+
// doStream...
|
187
|
+
if (args.tools?.length && lastMessageWasUser(args.messages)) {
|
188
|
+
const response = await this.doGenerate(options);
|
189
|
+
|
190
|
+
if ((response instanceof ReadableStream)) {
|
191
|
+
throw new Error("This shouldn't happen");
|
192
|
+
}
|
193
|
+
|
194
|
+
return {
|
195
|
+
stream: new ReadableStream<LanguageModelV1StreamPart>({
|
196
|
+
async start(controller) {
|
197
|
+
if (response.text) {
|
198
|
+
controller.enqueue({
|
199
|
+
type: "text-delta",
|
200
|
+
textDelta: response.text,
|
201
|
+
})
|
202
|
+
}
|
203
|
+
if (response.toolCalls) {
|
204
|
+
for (const toolCall of response.toolCalls) {
|
205
|
+
controller.enqueue({
|
206
|
+
type: "tool-call",
|
207
|
+
...toolCall,
|
208
|
+
})
|
209
|
+
}
|
210
|
+
}
|
211
|
+
controller.enqueue({
|
212
|
+
type: "finish",
|
213
|
+
finishReason: "stop",
|
214
|
+
usage: response.usage,
|
215
|
+
});
|
216
|
+
controller.close();
|
217
|
+
},
|
218
|
+
}),
|
219
|
+
rawCall: { rawPrompt: args.messages, rawSettings: args },
|
220
|
+
warnings,
|
221
|
+
};
|
222
|
+
}
|
223
|
+
|
224
|
+
// [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model.
|
183
225
|
const response = await this.config.binding.run(
|
184
226
|
args.model,
|
185
227
|
{
|
@@ -189,6 +231,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
189
231
|
temperature: args.temperature,
|
190
232
|
tools: args.tools,
|
191
233
|
top_p: args.top_p,
|
234
|
+
// @ts-expect-error response_format not yet added to types
|
235
|
+
response_format: args.response_format,
|
192
236
|
},
|
193
237
|
{ gateway: this.config.gateway ?? this.settings.gateway }
|
194
238
|
);
|
@@ -297,3 +341,7 @@ function prepareToolsAndToolChoice(
|
|
297
341
|
}
|
298
342
|
}
|
299
343
|
}
|
344
|
+
|
345
|
+
/**
 * Reports whether the final message in the prompt was authored by the user.
 * An empty prompt yields `false`.
 */
function lastMessageWasUser(messages: WorkersAIChatPrompt) {
  if (messages.length === 0) {
    return false;
  }
  const { role } = messages[messages.length - 1];
  return role === "user";
}
|