workers-ai-provider 0.1.2 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -58,6 +58,25 @@ export default {
58
58
  };
59
59
  ```
60
60
 
61
+ You can also use your Cloudflare credentials to create the provider, which is useful if you want to use Cloudflare AI outside of the Worker environment. For example, here is how you can use Cloudflare AI in a Node script:
62
+
63
+ ```js
64
+ const workersai = createWorkersAI({
65
+ accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
66
+ apiKey: process.env.CLOUDFLARE_API_KEY
67
+ });
68
+
69
+ const text = await streamText({
70
+ model: workersai("@cf/meta/llama-2-7b-chat-int8"),
71
+ messages: [
72
+ {
73
+ role: "user",
74
+ content: "Write an essay about hello world",
75
+ },
76
+ ],
77
+ });
78
+ ```
79
+
61
80
  For more info, refer to the documentation of the [Vercel AI SDK](https://sdk.vercel.ai/).
62
81
 
63
82
  ### Credits
package/dist/index.d.ts CHANGED
@@ -40,6 +40,32 @@ declare class WorkersAIChatLanguageModel implements LanguageModelV1 {
40
40
  doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
41
41
  }
42
42
 
43
+ type WorkersAISettings = ({
44
+ /**
45
+ * Provide a Cloudflare AI binding.
46
+ */
47
+ binding: Ai;
48
+ /**
49
+ * Credentials must be absent when a binding is given.
50
+ */
51
+ accountId?: never;
52
+ apiKey?: never;
53
+ } | {
54
+ /**
55
+ * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
56
+ */
57
+ accountId: string;
58
+ apiKey: string;
59
+ /**
60
+ * The binding must be absent if credentials are used directly.
61
+ */
62
+ binding?: never;
63
+ }) & {
64
+ /**
65
+ * Optionally specify a gateway.
66
+ */
67
+ gateway?: GatewayOptions;
68
+ };
43
69
  interface WorkersAI {
44
70
  (modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
45
71
  /**
@@ -47,26 +73,9 @@ interface WorkersAI {
47
73
  **/
48
74
  chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
49
75
  }
50
- interface WorkersAISettings {
51
- /**
52
- * Provide an `env.AI` binding to use for the AI inference.
53
- * You can set up an AI bindings in your Workers project
54
- * by adding the following this to `wrangler.toml`:
55
-
56
- ```toml
57
- [ai]
58
- binding = "AI"
59
- ```
60
- **/
61
- binding: Ai;
62
- /**
63
- * Optionally set Cloudflare AI Gateway options.
64
- */
65
- gateway?: GatewayOptions;
66
- }
67
76
  /**
68
77
  * Create a Workers AI provider instance.
69
- **/
78
+ */
70
79
  declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
71
80
 
72
81
  export { type WorkersAI, type WorkersAISettings, createWorkersAI };
package/dist/index.js CHANGED
@@ -240,6 +240,40 @@ var WorkersAIChatLanguageModel = class {
240
240
  }
241
241
  async doStream(options) {
242
242
  const { args, warnings } = this.getArgs(options);
243
+ if (args.tools?.length && lastMessageWasUser(args.messages)) {
244
+ const response2 = await this.doGenerate(options);
245
+ if (response2 instanceof ReadableStream) {
246
+ throw new Error("This shouldn't happen");
247
+ }
248
+ return {
249
+ stream: new ReadableStream({
250
+ async start(controller) {
251
+ if (response2.text) {
252
+ controller.enqueue({
253
+ type: "text-delta",
254
+ textDelta: response2.text
255
+ });
256
+ }
257
+ if (response2.toolCalls) {
258
+ for (const toolCall of response2.toolCalls) {
259
+ controller.enqueue({
260
+ type: "tool-call",
261
+ ...toolCall
262
+ });
263
+ }
264
+ }
265
+ controller.enqueue({
266
+ type: "finish",
267
+ finishReason: "stop",
268
+ usage: response2.usage
269
+ });
270
+ controller.close();
271
+ }
272
+ }),
273
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
274
+ warnings
275
+ };
276
+ }
243
277
  const response = await this.config.binding.run(
244
278
  args.model,
245
279
  {
@@ -248,7 +282,9 @@ var WorkersAIChatLanguageModel = class {
248
282
  stream: true,
249
283
  temperature: args.temperature,
250
284
  tools: args.tools,
251
- top_p: args.top_p
285
+ top_p: args.top_p,
286
+ // @ts-expect-error response_format not yet added to types
287
+ response_format: args.response_format
252
288
  },
253
289
  { gateway: this.config.gateway ?? this.settings.gateway }
254
290
  );
@@ -335,12 +371,57 @@ function prepareToolsAndToolChoice(mode) {
335
371
  }
336
372
  }
337
373
  }
374
+ function lastMessageWasUser(messages) {
375
+ return messages.length > 0 && messages[messages.length - 1].role === "user";
376
+ }
377
+
378
+ // src/utils.ts
379
+ function createRun(accountId, apiKey) {
380
+ return async (model, inputs, options) => {
381
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
382
+ const body = JSON.stringify(inputs);
383
+ const headers = {
384
+ "Content-Type": "application/json",
385
+ Authorization: `Bearer ${apiKey}`
386
+ };
387
+ const response = await fetch(url, {
388
+ method: "POST",
389
+ headers,
390
+ body
391
+ });
392
+ if (options?.returnRawResponse) {
393
+ return response;
394
+ }
395
+ if (inputs.stream === true) {
396
+ if (response.body) {
397
+ return response.body;
398
+ }
399
+ throw new Error("No readable body available for streaming.");
400
+ }
401
+ const data = await response.json();
402
+ return data.result;
403
+ };
404
+ }
338
405
 
339
406
  // src/index.ts
340
407
  function createWorkersAI(options) {
408
+ let binding;
409
+ if (options.binding) {
410
+ binding = options.binding;
411
+ } else {
412
+ const { accountId, apiKey } = options;
413
+ binding = {
414
+ run: createRun(accountId, apiKey)
415
+ };
416
+ }
417
+ if (!binding) {
418
+ throw new Error(
419
+ "Either a binding or credentials must be provided."
420
+ );
421
+ }
341
422
  const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
342
423
  provider: "workersai.chat",
343
- binding: options.binding,
424
+ binding,
344
425
  gateway: options.gateway
345
426
  });
346
427
  const provider = function(modelId, settings) {
package/dist/index.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"sources":["../src/workersai-chat-language-model.ts","../src/convert-to-workersai-chat-messages.ts","../src/map-workersai-usage.ts","../src/index.ts"],"sourcesContent":["import {\n type LanguageModelV1,\n type LanguageModelV1CallWarning,\n type LanguageModelV1StreamPart,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport { z } from \"zod\";\nimport { convertToWorkersAIChatMessages } from \"./convert-to-workersai-chat-messages\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nimport { events } from \"fetch-event-stream\";\nimport { mapWorkersAIUsage } from \"./map-workersai-usage\";\n\ntype WorkersAIChatConfig = {\n provider: string;\n binding: Ai;\n gateway?: GatewayOptions;\n};\n\nexport class WorkersAIChatLanguageModel implements LanguageModelV1 {\n readonly specificationVersion = \"v1\";\n readonly defaultObjectGenerationMode = \"json\";\n\n readonly modelId: TextGenerationModels;\n readonly settings: WorkersAIChatSettings;\n\n private readonly config: WorkersAIChatConfig;\n\n constructor(\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings,\n config: WorkersAIChatConfig\n ) {\n this.modelId = modelId;\n this.settings = settings;\n this.config = config;\n }\n\n get provider(): string {\n return this.config.provider;\n }\n\n private getArgs({\n mode,\n prompt,\n maxTokens,\n temperature,\n topP,\n frequencyPenalty,\n presencePenalty,\n seed,\n }: Parameters<LanguageModelV1[\"doGenerate\"]>[0]) {\n const type = mode.type;\n\n const warnings: LanguageModelV1CallWarning[] = [];\n\n if (frequencyPenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"frequencyPenalty\",\n });\n }\n\n if (presencePenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"presencePenalty\",\n });\n }\n\n const baseArgs = {\n // model id:\n model: this.modelId,\n\n // model specific 
settings:\n safe_prompt: this.settings.safePrompt,\n\n // standardized settings:\n max_tokens: maxTokens,\n temperature,\n top_p: topP,\n random_seed: seed,\n\n // messages:\n messages: convertToWorkersAIChatMessages(prompt),\n };\n\n switch (type) {\n case \"regular\": {\n return {\n args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },\n warnings,\n };\n }\n\n case \"object-json\": {\n return {\n args: {\n ...baseArgs,\n response_format: {\n type: \"json_schema\",\n json_schema: mode.schema,\n },\n tools: undefined,\n },\n warnings,\n };\n }\n\n case \"object-tool\": {\n return {\n args: {\n ...baseArgs,\n tool_choice: \"any\",\n tools: [{ type: \"function\", function: mode.tool }],\n },\n warnings,\n };\n }\n\n // @ts-expect-error - this is unreachable code\n // TODO: fixme\n case \"object-grammar\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"object-grammar mode\",\n });\n }\n\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported type: ${exhaustiveCheck}`);\n }\n }\n }\n\n async doGenerate(\n options: Parameters<LanguageModelV1[\"doGenerate\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doGenerate\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const output = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // @ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (output instanceof ReadableStream) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n text:\n typeof output.response === \"object\" && output.response !== null\n ? 
JSON.stringify(output.response) // ai-sdk expects a string here\n : output.response,\n toolCalls: output.tool_calls?.map((toolCall) => ({\n toolCallType: \"function\",\n toolCallId: toolCall.name,\n toolName: toolCall.name,\n args: JSON.stringify(toolCall.arguments || {}),\n })),\n finishReason: \"stop\", // TODO: mapWorkersAIFinishReason(response.finish_reason),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n usage: mapWorkersAIUsage(output),\n warnings,\n };\n }\n\n async doStream(\n options: Parameters<LanguageModelV1[\"doStream\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doStream\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const response = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n stream: true,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (!(response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n const chunkEvent = events(new Response(response));\n let usage = { promptTokens: 0, completionTokens: 0 };\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n for await (const event of chunkEvent) {\n if (!event.data) {\n continue;\n }\n if (event.data === \"[DONE]\") {\n break;\n }\n const chunk = JSON.parse(event.data);\n if (chunk.usage) {\n usage = mapWorkersAIUsage(chunk);\n }\n chunk.response.length &&\n controller.enqueue({\n type: \"text-delta\",\n textDelta: chunk.response,\n });\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n}\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst 
workersAIChatResponseSchema = z.object({\n response: z.string(),\n});\n\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatChunkSchema = z.instanceof(Uint8Array);\n\nfunction prepareToolsAndToolChoice(\n mode: Parameters<LanguageModelV1[\"doGenerate\"]>[0][\"mode\"] & {\n type: \"regular\";\n }\n) {\n // when the tools array is empty, change it to undefined to prevent errors:\n const tools = mode.tools?.length ? mode.tools : undefined;\n\n if (tools == null) {\n return { tools: undefined, tool_choice: undefined };\n }\n\n const mappedTools = tools.map((tool) => ({\n type: \"function\",\n function: {\n name: tool.name,\n // @ts-expect-error - description is not a property of tool\n description: tool.description,\n // @ts-expect-error - parameters is not a property of tool\n parameters: tool.parameters,\n },\n }));\n\n const toolChoice = mode.toolChoice;\n\n if (toolChoice == null) {\n return { tools: mappedTools, tool_choice: undefined };\n }\n\n const type = toolChoice.type;\n\n switch (type) {\n case \"auto\":\n return { tools: mappedTools, tool_choice: type };\n case \"none\":\n return { tools: mappedTools, tool_choice: type };\n case \"required\":\n return { tools: mappedTools, tool_choice: \"any\" };\n\n // workersAI does not support tool mode directly,\n // so we filter the tools and force the tool choice through 'any'\n case \"tool\":\n return {\n tools: mappedTools.filter(\n (tool) => tool.function.name === toolChoice.toolName\n ),\n tool_choice: \"any\",\n };\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);\n }\n }\n}\n","import {\n type LanguageModelV1Prompt,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\n// TODO\nexport function 
convertToWorkersAIChatMessages(\n prompt: LanguageModelV1Prompt\n): WorkersAIChatPrompt {\n const messages: WorkersAIChatPrompt = [];\n\n for (const { role, content } of prompt) {\n switch (role) {\n case \"system\": {\n messages.push({ role: \"system\", content });\n break;\n }\n\n case \"user\": {\n messages.push({\n role: \"user\",\n content: content\n .map((part) => {\n switch (part.type) {\n case \"text\": {\n return part.text;\n }\n case \"image\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"image-part\",\n });\n }\n }\n })\n .join(\"\"),\n });\n break;\n }\n\n case \"assistant\": {\n let text = \"\";\n const toolCalls: Array<{\n id: string;\n type: \"function\";\n function: { name: string; arguments: string };\n }> = [];\n\n for (const part of content) {\n switch (part.type) {\n case \"text\": {\n text += part.text;\n break;\n }\n case \"tool-call\": {\n text = JSON.stringify({ name: part.toolName, parameters: part.args });\n\n toolCalls.push({\n id: part.toolCallId,\n type: \"function\",\n function: {\n name: part.toolName,\n arguments: JSON.stringify(part.args),\n },\n });\n break;\n }\n default: {\n const exhaustiveCheck = part satisfies never;\n throw new Error(`Unsupported part: ${exhaustiveCheck}`);\n }\n }\n }\n\n messages.push({\n role: \"assistant\",\n content: text,\n tool_calls:\n toolCalls.length > 0\n ? 
toolCalls.map(({ function: { name, arguments: args } }) => ({\n id: \"null\",\n type: \"function\",\n function: { name, arguments: args },\n }))\n : undefined,\n });\n\n break;\n }\n case \"tool\": {\n for (const toolResponse of content) {\n messages.push({\n role: \"tool\",\n name: toolResponse.toolName,\n content: JSON.stringify(toolResponse.result),\n });\n }\n break;\n }\n default: {\n const exhaustiveCheck = role satisfies never;\n throw new Error(`Unsupported role: ${exhaustiveCheck}`);\n }\n }\n }\n\n return messages;\n}\n","export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {\n\tconst usage = (output as { usage: { prompt_tokens: number, completion_tokens: number} }).usage ?? {\n\t\tprompt_tokens: 0,\n\t\tcompletion_tokens: 0,\n\t};\n\n\treturn {\n\t\tpromptTokens: usage.prompt_tokens,\n\t\tcompletionTokens: usage.completion_tokens\n\t}\n}\n","import { WorkersAIChatLanguageModel } from \"./workersai-chat-language-model\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nexport interface WorkersAI {\n (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n\n /**\n * Creates a model for text generation.\n **/\n chat(\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n}\n\nexport interface WorkersAISettings {\n /**\n * Provide an `env.AI` binding to use for the AI inference.\n * You can set up an AI bindings in your Workers project\n * by adding the following this to `wrangler.toml`:\n\n ```toml\n[ai]\nbinding = \"AI\"\n ```\n **/\n binding: Ai;\n /**\n * Optionally set Cloudflare AI Gateway options.\n */\n gateway?: GatewayOptions;\n}\n\n/**\n * Create a Workers AI provider instance.\n **/\nexport function createWorkersAI(options: WorkersAISettings): WorkersAI {\n const createChatModel = (\n modelId: TextGenerationModels,\n 
settings: WorkersAIChatSettings = {}\n ) =>\n new WorkersAIChatLanguageModel(modelId, settings, {\n provider: \"workersai.chat\",\n binding: options.binding,\n gateway: options.gateway,\n });\n\n const provider = function (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ) {\n if (new.target) {\n throw new Error(\n \"The WorkersAI model function cannot be called with the new keyword.\"\n );\n }\n\n return createChatModel(modelId, settings);\n };\n\n provider.chat = createChatModel;\n\n return provider;\n}\n"],"mappings":";;;;;AAAA;AAAA,EAIE,iCAAAA;AAAA,OACK;AACP,SAAS,SAAS;;;ACNlB;AAAA,EAEE;AAAA,OACK;AAIA,SAAS,+BACd,QACqB;AACrB,QAAM,WAAgC,CAAC;AAEvC,aAAW,EAAE,MAAM,QAAQ,KAAK,QAAQ;AACtC,YAAQ,MAAM;AAAA,MACZ,KAAK,UAAU;AACb,iBAAS,KAAK,EAAE,MAAM,UAAU,QAAQ,CAAC;AACzC;AAAA,MACF;AAAA,MAEA,KAAK,QAAQ;AACX,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS,QACN,IAAI,CAAC,SAAS;AACb,oBAAQ,KAAK,MAAM;AAAA,cACjB,KAAK,QAAQ;AACX,uBAAO,KAAK;AAAA,cACd;AAAA,cACA,KAAK,SAAS;AACZ,sBAAM,IAAI,8BAA8B;AAAA,kBACtC,eAAe;AAAA,gBACjB,CAAC;AAAA,cACH;AAAA,YACF;AAAA,UACF,CAAC,EACA,KAAK,EAAE;AAAA,QACZ,CAAC;AACD;AAAA,MACF;AAAA,MAEA,KAAK,aAAa;AAChB,YAAI,OAAO;AACX,cAAM,YAID,CAAC;AAEN,mBAAW,QAAQ,SAAS;AAC1B,kBAAQ,KAAK,MAAM;AAAA,YACjB,KAAK,QAAQ;AACX,sBAAQ,KAAK;AACb;AAAA,YACF;AAAA,YACA,KAAK,aAAa;AAChB,qBAAO,KAAK,UAAU,EAAE,MAAM,KAAK,UAAU,YAAY,KAAK,KAAK,CAAC;AAEpE,wBAAU,KAAK;AAAA,gBACb,IAAI,KAAK;AAAA,gBACT,MAAM;AAAA,gBACN,UAAU;AAAA,kBACR,MAAM,KAAK;AAAA,kBACX,WAAW,KAAK,UAAU,KAAK,IAAI;AAAA,gBACrC;AAAA,cACF,CAAC;AACD;AAAA,YACF;AAAA,YACA,SAAS;AACP,oBAAM,kBAAkB;AACxB,oBAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,YACxD;AAAA,UACF;AAAA,QACF;AAEA,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS;AAAA,UACT,YACE,UAAU,SAAS,IACf,UAAU,IAAI,CAAC,EAAE,UAAU,EAAE,MAAM,WAAW,KAAK,EAAE,OAAO;AAAA,YAC1D,IAAI;AAAA,YACJ,MAAM;AAAA,YACN,UAAU,EAAE,MAAM,WAAW,KAAK;AAAA,UACpC,EAAE,IACF;AAAA,QACR,CAAC;AAED;AAAA,MACF;AAAA,MACA,KAAK,QAAQ;AACX,mBAAW,gBAAgB,SAAS;AAClC,mBAAS,KAAK;AAAA,YACZ,MAAM;AAAA,YACN,MAAM,aAAa;AAAA,YACnB,SAAS,KAAK,UAAU,aAAa,MAAM;AAA
A,UAC7C,CAAC;AAAA,QACH;AACA;AAAA,MACF;AAAA,MACA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAEA,SAAO;AACT;;;ADhGA,SAAS,cAAc;;;AEXhB,SAAS,kBAAkB,QAAuD;AACxF,QAAM,QAAS,OAA0E,SAAS;AAAA,IACjG,eAAe;AAAA,IACf,mBAAmB;AAAA,EACpB;AAEA,SAAO;AAAA,IACN,cAAc,MAAM;AAAA,IACpB,kBAAkB,MAAM;AAAA,EACzB;AACD;;;AFUO,IAAM,6BAAN,MAA4D;AAAA,EASjE,YACE,SACA,UACA,QACA;AAZF,wBAAS,wBAAuB;AAChC,wBAAS,+BAA8B;AAEvC,wBAAS;AACT,wBAAS;AAET,wBAAiB;AAOf,SAAK,UAAU;AACf,SAAK,WAAW;AAChB,SAAK,SAAS;AAAA,EAChB;AAAA,EAEA,IAAI,WAAmB;AACrB,WAAO,KAAK,OAAO;AAAA,EACrB;AAAA,EAEQ,QAAQ;AAAA,IACd;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,EACF,GAAiD;AAC/C,UAAM,OAAO,KAAK;AAElB,UAAM,WAAyC,CAAC;AAEhD,QAAI,oBAAoB,MAAM;AAC5B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,QAAI,mBAAmB,MAAM;AAC3B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,UAAM,WAAW;AAAA;AAAA,MAEf,OAAO,KAAK;AAAA;AAAA,MAGZ,aAAa,KAAK,SAAS;AAAA;AAAA,MAG3B,YAAY;AAAA,MACZ;AAAA,MACA,OAAO;AAAA,MACP,aAAa;AAAA;AAAA,MAGb,UAAU,+BAA+B,MAAM;AAAA,IACjD;AAEA,YAAQ,MAAM;AAAA,MACZ,KAAK,WAAW;AACd,eAAO;AAAA,UACL,MAAM,EAAE,GAAG,UAAU,GAAG,0BAA0B,IAAI,EAAE;AAAA,UACxD;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,iBAAiB;AAAA,cACf,MAAM;AAAA,cACN,aAAa,KAAK;AAAA,YACpB;AAAA,YACA,OAAO;AAAA,UACT;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,aAAa;AAAA,YACb,OAAO,CAAC,EAAE,MAAM,YAAY,UAAU,KAAK,KAAK,CAAC;AAAA,UACnD;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA;AAAA;AAAA,MAIA,KAAK,kBAAkB;AACrB,cAAM,IAAIC,+BAA8B;AAAA,UACtC,eAAe;AAAA,QACjB,CAAC;AAAA,MACH;AAAA,MAEA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,WACJ,SAC6D;AAC7D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,SAAS,MAAM,KAAK,OAAO,QAAQ;AAAA,MACvC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,K
AAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,kBAAkB,gBAAgB;AACpC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,WAAO;AAAA,MACL,MACE,OAAO,OAAO,aAAa,YAAY,OAAO,aAAa,OACvD,KAAK,UAAU,OAAO,QAAQ,IAC9B,OAAO;AAAA,MACb,WAAW,OAAO,YAAY,IAAI,CAAC,cAAc;AAAA,QAC/C,cAAc;AAAA,QACd,YAAY,SAAS;AAAA,QACrB,UAAU,SAAS;AAAA,QACnB,MAAM,KAAK,UAAU,SAAS,aAAa,CAAC,CAAC;AAAA,MAC/C,EAAE;AAAA,MACF,cAAc;AAAA;AAAA,MACd,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD,OAAO,kBAAkB,MAAM;AAAA,MAC/B;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,SACJ,SAC2D;AAC3D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,WAAW,MAAM,KAAK,OAAO,QAAQ;AAAA,MACzC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,QAAQ;AAAA,QACR,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA,MACd;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,EAAE,oBAAoB,iBAAiB;AACzC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,UAAM,aAAa,OAAO,IAAI,SAAS,QAAQ,CAAC;AAChD,QAAI,QAAQ,EAAE,cAAc,GAAG,kBAAkB,EAAE;AAEnD,WAAO;AAAA,MACL,QAAQ,IAAI,eAA0C;AAAA,QACpD,MAAM,MAAM,YAAY;AACtB,2BAAiB,SAAS,YAAY;AACpC,gBAAI,CAAC,MAAM,MAAM;AACf;AAAA,YACF;AACA,gBAAI,MAAM,SAAS,UAAU;AAC3B;AAAA,YACF;AACA,kBAAM,QAAQ,KAAK,MAAM,MAAM,IAAI;AACnC,gBAAI,MAAM,OAAO;AACf,sBAAQ,kBAAkB,KAAK;AAAA,YACjC;AACA,kBAAM,SAAS,UACb,WAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,WAAW,MAAM;AAAA,YACnB,CAAC;AAAA,UACL;AACA,qBAAW,QAAQ;AAAA,YACjB,MAAM;AAAA,YACN,cAAc;AAAA,YACd;AAAA,UACF,CAAC;AACD,qBAAW,MAAM;AAAA,QACnB;AAAA,MACF,CAAC;AAAA,MACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD;AAAA,IACF;AAAA,EACF;AACF;AAGA,IAAM,8BAA8B,EAAE,OAAO;AAAA,EAC3C,UAAU,EAAE,OAAO;AACrB,CAAC;AAID,IAAM,2BAA2B,EAAE,WAAW,UAAU;AAExD,SAAS,0BACP,MAGA;AAEA,QAAM,QAAQ,KAAK,OAAO,SAAS,KAAK,QAAQ;AAEhD,MAAI,SAAS,MAAM;AACjB,WAAO,EAAE,OAAO,QAAW,aAAa,OAAU;AAAA,EACpD;AAEA,QAAM,cAAc,MAAM,IAAI,CAAC,UAAU;AAAA,IACvC,MAAM;AAAA,IACN,UAAU;AAAA,MACR,MAAM,KAAK;AAAA;AAAA,MAEX,aAAa,KAAK;AAAA;AAAA,MAElB,YAAY,KAAK;AAAA,IACnB;AAAA,EACF,EAAE;AAEF,QAAM,aAAa,KAAK;AAExB,MAAI,cAAc,MAAM;AACtB,
WAAO,EAAE,OAAO,aAAa,aAAa,OAAU;AAAA,EACtD;AAEA,QAAM,OAAO,WAAW;AAExB,UAAQ,MAAM;AAAA,IACZ,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,MAAM;AAAA;AAAA;AAAA,IAIlD,KAAK;AACH,aAAO;AAAA,QACL,OAAO,YAAY;AAAA,UACjB,CAAC,SAAS,KAAK,SAAS,SAAS,WAAW;AAAA,QAC9C;AAAA,QACA,aAAa;AAAA,MACf;AAAA,IACF,SAAS;AACP,YAAM,kBAAkB;AACxB,YAAM,IAAI,MAAM,iCAAiC,eAAe,EAAE;AAAA,IACpE;AAAA,EACF;AACF;;;AGlQO,SAAS,gBAAgB,SAAuC;AACrE,QAAM,kBAAkB,CACtB,SACA,WAAkC,CAAC,MAEnC,IAAI,2BAA2B,SAAS,UAAU;AAAA,IAChD,UAAU;AAAA,IACV,SAAS,QAAQ;AAAA,IACjB,SAAS,QAAQ;AAAA,EACnB,CAAC;AAEH,QAAM,WAAW,SACf,SACA,UACA;AACA,QAAI,YAAY;AACd,YAAM,IAAI;AAAA,QACR;AAAA,MACF;AAAA,IACF;AAEA,WAAO,gBAAgB,SAAS,QAAQ;AAAA,EAC1C;AAEA,WAAS,OAAO;AAEhB,SAAO;AACT;","names":["UnsupportedFunctionalityError","UnsupportedFunctionalityError"]}
1
+ {"version":3,"sources":["../src/workersai-chat-language-model.ts","../src/convert-to-workersai-chat-messages.ts","../src/map-workersai-usage.ts","../src/utils.ts","../src/index.ts"],"sourcesContent":["import {\n type LanguageModelV1,\n type LanguageModelV1CallWarning,\n type LanguageModelV1StreamPart,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport { z } from \"zod\";\nimport { convertToWorkersAIChatMessages } from \"./convert-to-workersai-chat-messages\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\n\nimport { events } from \"fetch-event-stream\";\nimport { mapWorkersAIUsage } from \"./map-workersai-usage\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\ntype WorkersAIChatConfig = {\n provider: string;\n binding: Ai;\n gateway?: GatewayOptions;\n};\n\nexport class WorkersAIChatLanguageModel implements LanguageModelV1 {\n readonly specificationVersion = \"v1\";\n readonly defaultObjectGenerationMode = \"json\";\n\n readonly modelId: TextGenerationModels;\n readonly settings: WorkersAIChatSettings;\n\n private readonly config: WorkersAIChatConfig;\n\n constructor(\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings,\n config: WorkersAIChatConfig\n ) {\n this.modelId = modelId;\n this.settings = settings;\n this.config = config;\n }\n\n get provider(): string {\n return this.config.provider;\n }\n\n private getArgs({\n mode,\n prompt,\n maxTokens,\n temperature,\n topP,\n frequencyPenalty,\n presencePenalty,\n seed,\n }: Parameters<LanguageModelV1[\"doGenerate\"]>[0]) {\n const type = mode.type;\n\n const warnings: LanguageModelV1CallWarning[] = [];\n\n if (frequencyPenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"frequencyPenalty\",\n });\n }\n\n if (presencePenalty != null) {\n warnings.push({\n type: \"unsupported-setting\",\n setting: \"presencePenalty\",\n });\n 
}\n\n const baseArgs = {\n // model id:\n model: this.modelId,\n\n // model specific settings:\n safe_prompt: this.settings.safePrompt,\n\n // standardized settings:\n max_tokens: maxTokens,\n temperature,\n top_p: topP,\n random_seed: seed,\n\n // messages:\n messages: convertToWorkersAIChatMessages(prompt),\n };\n\n switch (type) {\n case \"regular\": {\n return {\n args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },\n warnings,\n };\n }\n\n case \"object-json\": {\n return {\n args: {\n ...baseArgs,\n response_format: {\n type: \"json_schema\",\n json_schema: mode.schema,\n },\n tools: undefined,\n },\n warnings,\n };\n }\n\n case \"object-tool\": {\n return {\n args: {\n ...baseArgs,\n tool_choice: \"any\",\n tools: [{ type: \"function\", function: mode.tool }],\n },\n warnings,\n };\n }\n\n // @ts-expect-error - this is unreachable code\n // TODO: fixme\n case \"object-grammar\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"object-grammar mode\",\n });\n }\n\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported type: ${exhaustiveCheck}`);\n }\n }\n }\n\n async doGenerate(\n options: Parameters<LanguageModelV1[\"doGenerate\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doGenerate\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n const output = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // @ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (output instanceof ReadableStream) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n text:\n typeof output.response === \"object\" && output.response !== null\n ? 
JSON.stringify(output.response) // ai-sdk expects a string here\n : output.response,\n toolCalls: output.tool_calls?.map((toolCall) => ({\n toolCallType: \"function\",\n toolCallId: toolCall.name,\n toolName: toolCall.name,\n args: JSON.stringify(toolCall.arguments || {}),\n })),\n finishReason: \"stop\", // TODO: mapWorkersAIFinishReason(response.finish_reason),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n usage: mapWorkersAIUsage(output),\n warnings,\n };\n }\n\n async doStream(\n options: Parameters<LanguageModelV1[\"doStream\"]>[0]\n ): Promise<Awaited<ReturnType<LanguageModelV1[\"doStream\"]>>> {\n const { args, warnings } = this.getArgs(options);\n\n // [1] When the latest message is not a tool response, we use the regular generate function\n // and simulate it as a streamed response in order to satisfy the AI SDK's interface for\n // doStream...\n if (args.tools?.length && lastMessageWasUser(args.messages)) {\n const response = await this.doGenerate(options);\n\n if ((response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n if (response.text) {\n controller.enqueue({\n type: \"text-delta\",\n textDelta: response.text,\n })\n }\n if (response.toolCalls) {\n for (const toolCall of response.toolCalls) {\n controller.enqueue({\n type: \"tool-call\",\n ...toolCall,\n })\n }\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: response.usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n\n // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model.\n const response = await this.config.binding.run(\n args.model,\n {\n messages: args.messages,\n max_tokens: args.max_tokens,\n stream: true,\n temperature: args.temperature,\n tools: args.tools,\n top_p: args.top_p,\n // 
@ts-expect-error response_format not yet added to types\n response_format: args.response_format,\n },\n { gateway: this.config.gateway ?? this.settings.gateway }\n );\n\n if (!(response instanceof ReadableStream)) {\n throw new Error(\"This shouldn't happen\");\n }\n\n const chunkEvent = events(new Response(response));\n let usage = { promptTokens: 0, completionTokens: 0 };\n\n return {\n stream: new ReadableStream<LanguageModelV1StreamPart>({\n async start(controller) {\n for await (const event of chunkEvent) {\n if (!event.data) {\n continue;\n }\n if (event.data === \"[DONE]\") {\n break;\n }\n const chunk = JSON.parse(event.data);\n if (chunk.usage) {\n usage = mapWorkersAIUsage(chunk);\n }\n chunk.response.length &&\n controller.enqueue({\n type: \"text-delta\",\n textDelta: chunk.response,\n });\n }\n controller.enqueue({\n type: \"finish\",\n finishReason: \"stop\",\n usage: usage,\n });\n controller.close();\n },\n }),\n rawCall: { rawPrompt: args.messages, rawSettings: args },\n warnings,\n };\n }\n}\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatResponseSchema = z.object({\n response: z.string(),\n});\n\n// limited version of the schema, focussed on what is needed for the implementation\n// this approach limits breakages when the API changes and increases efficiency\nconst workersAIChatChunkSchema = z.instanceof(Uint8Array);\n\nfunction prepareToolsAndToolChoice(\n mode: Parameters<LanguageModelV1[\"doGenerate\"]>[0][\"mode\"] & {\n type: \"regular\";\n }\n) {\n // when the tools array is empty, change it to undefined to prevent errors:\n const tools = mode.tools?.length ? 
mode.tools : undefined;\n\n if (tools == null) {\n return { tools: undefined, tool_choice: undefined };\n }\n\n const mappedTools = tools.map((tool) => ({\n type: \"function\",\n function: {\n name: tool.name,\n // @ts-expect-error - description is not a property of tool\n description: tool.description,\n // @ts-expect-error - parameters is not a property of tool\n parameters: tool.parameters,\n },\n }));\n\n const toolChoice = mode.toolChoice;\n\n if (toolChoice == null) {\n return { tools: mappedTools, tool_choice: undefined };\n }\n\n const type = toolChoice.type;\n\n switch (type) {\n case \"auto\":\n return { tools: mappedTools, tool_choice: type };\n case \"none\":\n return { tools: mappedTools, tool_choice: type };\n case \"required\":\n return { tools: mappedTools, tool_choice: \"any\" };\n\n // workersAI does not support tool mode directly,\n // so we filter the tools and force the tool choice through 'any'\n case \"tool\":\n return {\n tools: mappedTools.filter(\n (tool) => tool.function.name === toolChoice.toolName\n ),\n tool_choice: \"any\",\n };\n default: {\n const exhaustiveCheck = type satisfies never;\n throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);\n }\n }\n}\n\nfunction lastMessageWasUser(messages: WorkersAIChatPrompt) {\n return messages.length > 0 && messages[messages.length - 1].role === \"user\";\n}\n","import {\n type LanguageModelV1Prompt,\n UnsupportedFunctionalityError,\n} from \"@ai-sdk/provider\";\nimport type { WorkersAIChatPrompt } from \"./workersai-chat-prompt\";\n\n// TODO\nexport function convertToWorkersAIChatMessages(\n prompt: LanguageModelV1Prompt\n): WorkersAIChatPrompt {\n const messages: WorkersAIChatPrompt = [];\n\n for (const { role, content } of prompt) {\n switch (role) {\n case \"system\": {\n messages.push({ role: \"system\", content });\n break;\n }\n\n case \"user\": {\n messages.push({\n role: \"user\",\n content: content\n .map((part) => {\n switch (part.type) {\n case \"text\": {\n return 
part.text;\n }\n case \"image\": {\n throw new UnsupportedFunctionalityError({\n functionality: \"image-part\",\n });\n }\n }\n })\n .join(\"\"),\n });\n break;\n }\n\n case \"assistant\": {\n let text = \"\";\n const toolCalls: Array<{\n id: string;\n type: \"function\";\n function: { name: string; arguments: string };\n }> = [];\n\n for (const part of content) {\n switch (part.type) {\n case \"text\": {\n text += part.text;\n break;\n }\n case \"tool-call\": {\n text = JSON.stringify({ name: part.toolName, parameters: part.args });\n\n toolCalls.push({\n id: part.toolCallId,\n type: \"function\",\n function: {\n name: part.toolName,\n arguments: JSON.stringify(part.args),\n },\n });\n break;\n }\n default: {\n const exhaustiveCheck = part satisfies never;\n throw new Error(`Unsupported part: ${exhaustiveCheck}`);\n }\n }\n }\n\n messages.push({\n role: \"assistant\",\n content: text,\n tool_calls:\n toolCalls.length > 0\n ? toolCalls.map(({ function: { name, arguments: args } }) => ({\n id: \"null\",\n type: \"function\",\n function: { name, arguments: args },\n }))\n : undefined,\n });\n\n break;\n }\n case \"tool\": {\n for (const toolResponse of content) {\n messages.push({\n role: \"tool\",\n name: toolResponse.toolName,\n content: JSON.stringify(toolResponse.result),\n });\n }\n break;\n }\n default: {\n const exhaustiveCheck = role satisfies never;\n throw new Error(`Unsupported role: ${exhaustiveCheck}`);\n }\n }\n }\n\n return messages;\n}\n","export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {\n\tconst usage = (output as { usage: { prompt_tokens: number, completion_tokens: number} }).usage ?? 
{\n\t\tprompt_tokens: 0,\n\t\tcompletion_tokens: 0,\n\t};\n\n\treturn {\n\t\tpromptTokens: usage.prompt_tokens,\n\t\tcompletionTokens: usage.completion_tokens\n\t}\n}\n","/**\n * Creates a run method that mimics the Cloudflare Workers AI binding,\n * but uses the Cloudflare REST API under the hood.\n *\n * @param accountId - Your Cloudflare account identifier.\n * @param apiKey - Your Cloudflare API token/key with appropriate permissions.\n * @returns An function matching `Ai['run']`.\n */\nexport function createRun(accountId: string, apiKey: string): AiRun {\n return async <Name extends keyof AiModels>(model: Name, inputs: AiModels[Name][\"inputs\"], options?: AiOptions | undefined) => {\n const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;\n const body = JSON.stringify(inputs);\n\n const headers = {\n \"Content-Type\": \"application/json\",\n Authorization: `Bearer ${apiKey}`,\n };\n\n const response = await fetch(url, {\n method: \"POST\",\n headers,\n body,\n }) as Response;\n\n if (options?.returnRawResponse) {\n return response;\n }\n\n if ((inputs as AiTextGenerationInput).stream === true) {\n // If there's a stream, return the raw body so the caller can process it\n if (response.body) {\n return response.body;\n }\n throw new Error(\"No readable body available for streaming.\");\n }\n\n // Otherwise, parse JSON and return the data.result\n const data = await response.json<{ result: AiModels[Name][\"postProcessedOutputs\"] }>();\n return data.result;\n };\n}\n\ninterface AiRun {\n // (1) Return raw response if `options.returnRawResponse` is `true`.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"],\n options: AiOptions & { returnRawResponse: true }\n ): Promise<Response>;\n\n // (2) Return a stream if the input has `stream: true`.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"] & { stream: true },\n options?: AiOptions\n ): 
Promise<ReadableStream<Uint8Array>>;\n\n // (3) Return the post-processed outputs by default.\n <Name extends keyof AiModels>(\n model: Name,\n inputs: AiModels[Name][\"inputs\"],\n options?: AiOptions\n ): Promise<AiModels[Name][\"postProcessedOutputs\"]>;\n}\n","import { WorkersAIChatLanguageModel } from \"./workersai-chat-language-model\";\nimport type { WorkersAIChatSettings } from \"./workersai-chat-settings\";\nimport type { TextGenerationModels } from \"./workersai-models\";\nimport { createRun } from \"./utils\";\n\nexport type WorkersAISettings =\n ({\n /**\n * Provide a Cloudflare AI binding.\n */\n binding: Ai;\n\n /**\n * Credentials must be absent when a binding is given.\n */\n accountId?: never;\n apiKey?: never;\n}\n | {\n /**\n * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.\n */\n accountId: string;\n apiKey: string;\n /**\n * Both binding must be absent if credentials are used directly.\n */\n binding?: never;\n}) & {\n /**\n * Optionally specify a gateway.\n */\n gateway?: GatewayOptions;\n};\n\nexport interface WorkersAI {\n (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n\n /**\n * Creates a model for text generation.\n **/\n chat(\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ): WorkersAIChatLanguageModel;\n}\n\n/**\n * Create a Workers AI provider instance.\n */\nexport function createWorkersAI(options: WorkersAISettings): WorkersAI {\n // Use a binding if one is directly provided. 
Otherwise use credentials to create\n // a `run` method that calls the Cloudflare REST API.\n let binding: Ai | undefined;\n\n if (options.binding) {\n binding = options.binding;\n } else {\n const { accountId, apiKey } = options;\n binding = {\n run: createRun(accountId, apiKey),\n } as Ai;\n }\n\n if (!binding) {\n throw new Error(\n \"Either a binding or credentials must be provided.\"\n );\n }\n\n /**\n * Helper function to create a chat model instance.\n */\n const createChatModel = (\n modelId: TextGenerationModels,\n settings: WorkersAIChatSettings = {}\n ) =>\n new WorkersAIChatLanguageModel(modelId, settings, {\n provider: \"workersai.chat\",\n binding,\n gateway: options.gateway\n });\n\n const provider = function (\n modelId: TextGenerationModels,\n settings?: WorkersAIChatSettings\n ) {\n if (new.target) {\n throw new Error(\n \"The WorkersAI model function cannot be called with the new keyword.\"\n );\n }\n return createChatModel(modelId, settings);\n };\n\n provider.chat = createChatModel;\n\n return 
provider;\n}\n\n"],"mappings":";;;;;AAAA;AAAA,EAIE,iCAAAA;AAAA,OACK;AACP,SAAS,SAAS;;;ACNlB;AAAA,EAEE;AAAA,OACK;AAIA,SAAS,+BACd,QACqB;AACrB,QAAM,WAAgC,CAAC;AAEvC,aAAW,EAAE,MAAM,QAAQ,KAAK,QAAQ;AACtC,YAAQ,MAAM;AAAA,MACZ,KAAK,UAAU;AACb,iBAAS,KAAK,EAAE,MAAM,UAAU,QAAQ,CAAC;AACzC;AAAA,MACF;AAAA,MAEA,KAAK,QAAQ;AACX,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS,QACN,IAAI,CAAC,SAAS;AACb,oBAAQ,KAAK,MAAM;AAAA,cACjB,KAAK,QAAQ;AACX,uBAAO,KAAK;AAAA,cACd;AAAA,cACA,KAAK,SAAS;AACZ,sBAAM,IAAI,8BAA8B;AAAA,kBACtC,eAAe;AAAA,gBACjB,CAAC;AAAA,cACH;AAAA,YACF;AAAA,UACF,CAAC,EACA,KAAK,EAAE;AAAA,QACZ,CAAC;AACD;AAAA,MACF;AAAA,MAEA,KAAK,aAAa;AAChB,YAAI,OAAO;AACX,cAAM,YAID,CAAC;AAEN,mBAAW,QAAQ,SAAS;AAC1B,kBAAQ,KAAK,MAAM;AAAA,YACjB,KAAK,QAAQ;AACX,sBAAQ,KAAK;AACb;AAAA,YACF;AAAA,YACA,KAAK,aAAa;AAChB,qBAAO,KAAK,UAAU,EAAE,MAAM,KAAK,UAAU,YAAY,KAAK,KAAK,CAAC;AAEpE,wBAAU,KAAK;AAAA,gBACb,IAAI,KAAK;AAAA,gBACT,MAAM;AAAA,gBACN,UAAU;AAAA,kBACR,MAAM,KAAK;AAAA,kBACX,WAAW,KAAK,UAAU,KAAK,IAAI;AAAA,gBACrC;AAAA,cACF,CAAC;AACD;AAAA,YACF;AAAA,YACA,SAAS;AACP,oBAAM,kBAAkB;AACxB,oBAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,YACxD;AAAA,UACF;AAAA,QACF;AAEA,iBAAS,KAAK;AAAA,UACZ,MAAM;AAAA,UACN,SAAS;AAAA,UACT,YACE,UAAU,SAAS,IACf,UAAU,IAAI,CAAC,EAAE,UAAU,EAAE,MAAM,WAAW,KAAK,EAAE,OAAO;AAAA,YAC1D,IAAI;AAAA,YACJ,MAAM;AAAA,YACN,UAAU,EAAE,MAAM,WAAW,KAAK;AAAA,UACpC,EAAE,IACF;AAAA,QACR,CAAC;AAED;AAAA,MACF;AAAA,MACA,KAAK,QAAQ;AACX,mBAAW,gBAAgB,SAAS;AAClC,mBAAS,KAAK;AAAA,YACZ,MAAM;AAAA,YACN,MAAM,aAAa;AAAA,YACnB,SAAS,KAAK,UAAU,aAAa,MAAM;AAAA,UAC7C,CAAC;AAAA,QACH;AACA;AAAA,MACF;AAAA,MACA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAEA,SAAO;AACT;;;ADhGA,SAAS,cAAc;;;AEXhB,SAAS,kBAAkB,QAAuD;AACxF,QAAM,QAAS,OAA0E,SAAS;AAAA,IACjG,eAAe;AAAA,IACf,mBAAmB;AAAA,EACpB;AAEA,SAAO;AAAA,IACN,cAAc,MAAM;AAAA,IACpB,kBAAkB,MAAM;AAAA,EACzB;AACD;;;AFWO,IAAM,6BAAN,MAA4D;AAAA,EASjE,YACE,SACA,UACA,QACA;AAZF,wBAAS,wBAAuB;AAChC,wBAAS,+BAA8B;AAEvC,wBAAS;AACT,wBAAS;AAET,wBAAiB;AAOf,SAAK,UAAU;AACf,SAAK,WAAW;AAChB,SAAK,SAAS;AAAA,EA
ChB;AAAA,EAEA,IAAI,WAAmB;AACrB,WAAO,KAAK,OAAO;AAAA,EACrB;AAAA,EAEQ,QAAQ;AAAA,IACd;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,IACA;AAAA,EACF,GAAiD;AAC/C,UAAM,OAAO,KAAK;AAElB,UAAM,WAAyC,CAAC;AAEhD,QAAI,oBAAoB,MAAM;AAC5B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,QAAI,mBAAmB,MAAM;AAC3B,eAAS,KAAK;AAAA,QACZ,MAAM;AAAA,QACN,SAAS;AAAA,MACX,CAAC;AAAA,IACH;AAEA,UAAM,WAAW;AAAA;AAAA,MAEf,OAAO,KAAK;AAAA;AAAA,MAGZ,aAAa,KAAK,SAAS;AAAA;AAAA,MAG3B,YAAY;AAAA,MACZ;AAAA,MACA,OAAO;AAAA,MACP,aAAa;AAAA;AAAA,MAGb,UAAU,+BAA+B,MAAM;AAAA,IACjD;AAEA,YAAQ,MAAM;AAAA,MACZ,KAAK,WAAW;AACd,eAAO;AAAA,UACL,MAAM,EAAE,GAAG,UAAU,GAAG,0BAA0B,IAAI,EAAE;AAAA,UACxD;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,iBAAiB;AAAA,cACf,MAAM;AAAA,cACN,aAAa,KAAK;AAAA,YACpB;AAAA,YACA,OAAO;AAAA,UACT;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA,MAEA,KAAK,eAAe;AAClB,eAAO;AAAA,UACL,MAAM;AAAA,YACJ,GAAG;AAAA,YACH,aAAa;AAAA,YACb,OAAO,CAAC,EAAE,MAAM,YAAY,UAAU,KAAK,KAAK,CAAC;AAAA,UACnD;AAAA,UACA;AAAA,QACF;AAAA,MACF;AAAA;AAAA;AAAA,MAIA,KAAK,kBAAkB;AACrB,cAAM,IAAIC,+BAA8B;AAAA,UACtC,eAAe;AAAA,QACjB,CAAC;AAAA,MACH;AAAA,MAEA,SAAS;AACP,cAAM,kBAAkB;AACxB,cAAM,IAAI,MAAM,qBAAqB,eAAe,EAAE;AAAA,MACxD;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,WACJ,SAC6D;AAC7D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAE/C,UAAM,SAAS,MAAM,KAAK,OAAO,QAAQ;AAAA,MACvC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,kBAAkB,gBAAgB;AACpC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,WAAO;AAAA,MACL,MACE,OAAO,OAAO,aAAa,YAAY,OAAO,aAAa,OACvD,KAAK,UAAU,OAAO,QAAQ,IAC9B,OAAO;AAAA,MACb,WAAW,OAAO,YAAY,IAAI,CAAC,cAAc;AAAA,QAC/C,cAAc;AAAA,QACd,YAAY,SAAS;AAAA,QACrB,UAAU,SAAS;AAAA,QACnB,MAAM,KAAK,UAAU,SAAS,aAAa,CAAC,CAAC;AAAA,MAC/C,EAAE;AAAA,MACF,cAAc;AAAA;AAAA,MACd,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD,OAAO,kBAAkB,MAAM;AAAA,MAC
/B;AAAA,IACF;AAAA,EACF;AAAA,EAEA,MAAM,SACJ,SAC2D;AAC3D,UAAM,EAAE,MAAM,SAAS,IAAI,KAAK,QAAQ,OAAO;AAK/C,QAAI,KAAK,OAAO,UAAU,mBAAmB,KAAK,QAAQ,GAAG;AAC3D,YAAMC,YAAW,MAAM,KAAK,WAAW,OAAO;AAE9C,UAAKA,qBAAoB,gBAAiB;AACxC,cAAM,IAAI,MAAM,uBAAuB;AAAA,MACzC;AAEA,aAAO;AAAA,QACL,QAAQ,IAAI,eAA0C;AAAA,UACpD,MAAM,MAAM,YAAY;AACtB,gBAAIA,UAAS,MAAM;AACjB,yBAAW,QAAQ;AAAA,gBACjB,MAAM;AAAA,gBACN,WAAWA,UAAS;AAAA,cACtB,CAAC;AAAA,YACH;AACA,gBAAIA,UAAS,WAAW;AACtB,yBAAW,YAAYA,UAAS,WAAW;AACzC,2BAAW,QAAQ;AAAA,kBACjB,MAAM;AAAA,kBACN,GAAG;AAAA,gBACL,CAAC;AAAA,cACH;AAAA,YACF;AACA,uBAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,cAAc;AAAA,cACd,OAAOA,UAAS;AAAA,YAClB,CAAC;AACD,uBAAW,MAAM;AAAA,UACnB;AAAA,QACF,CAAC;AAAA,QACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,QACvD;AAAA,MACF;AAAA,IACF;AAGA,UAAM,WAAW,MAAM,KAAK,OAAO,QAAQ;AAAA,MACzC,KAAK;AAAA,MACL;AAAA,QACE,UAAU,KAAK;AAAA,QACf,YAAY,KAAK;AAAA,QACjB,QAAQ;AAAA,QACR,aAAa,KAAK;AAAA,QAClB,OAAO,KAAK;AAAA,QACZ,OAAO,KAAK;AAAA;AAAA,QAEZ,iBAAiB,KAAK;AAAA,MACxB;AAAA,MACA,EAAE,SAAS,KAAK,OAAO,WAAW,KAAK,SAAS,QAAQ;AAAA,IAC1D;AAEA,QAAI,EAAE,oBAAoB,iBAAiB;AACzC,YAAM,IAAI,MAAM,uBAAuB;AAAA,IACzC;AAEA,UAAM,aAAa,OAAO,IAAI,SAAS,QAAQ,CAAC;AAChD,QAAI,QAAQ,EAAE,cAAc,GAAG,kBAAkB,EAAE;AAEnD,WAAO;AAAA,MACL,QAAQ,IAAI,eAA0C;AAAA,QACpD,MAAM,MAAM,YAAY;AACtB,2BAAiB,SAAS,YAAY;AACpC,gBAAI,CAAC,MAAM,MAAM;AACf;AAAA,YACF;AACA,gBAAI,MAAM,SAAS,UAAU;AAC3B;AAAA,YACF;AACA,kBAAM,QAAQ,KAAK,MAAM,MAAM,IAAI;AACnC,gBAAI,MAAM,OAAO;AACf,sBAAQ,kBAAkB,KAAK;AAAA,YACjC;AACA,kBAAM,SAAS,UACb,WAAW,QAAQ;AAAA,cACjB,MAAM;AAAA,cACN,WAAW,MAAM;AAAA,YACnB,CAAC;AAAA,UACL;AACA,qBAAW,QAAQ;AAAA,YACjB,MAAM;AAAA,YACN,cAAc;AAAA,YACd;AAAA,UACF,CAAC;AACD,qBAAW,MAAM;AAAA,QACnB;AAAA,MACF,CAAC;AAAA,MACD,SAAS,EAAE,WAAW,KAAK,UAAU,aAAa,KAAK;AAAA,MACvD;AAAA,IACF;AAAA,EACF;AACF;AAGA,IAAM,8BAA8B,EAAE,OAAO;AAAA,EAC3C,UAAU,EAAE,OAAO;AACrB,CAAC;AAID,IAAM,2BAA2B,EAAE,WAAW,UAAU;AAExD,SAAS,0BACP,MAGA;AAEA,QAAM,QAAQ,KAAK,OAAO,SAAS,KAAK,QAAQ;AAEhD,MAAI,SAAS,MAAM;AACjB,WAAO,EAAE,OAAO,QAAW,aAAa,OAAU;AAAA,EACpD;AAEA,QAAM,cAAc,MAAM,IAAI,CAAC,UAAU;AAAA,IACvC,
MAAM;AAAA,IACN,UAAU;AAAA,MACR,MAAM,KAAK;AAAA;AAAA,MAEX,aAAa,KAAK;AAAA;AAAA,MAElB,YAAY,KAAK;AAAA,IACnB;AAAA,EACF,EAAE;AAEF,QAAM,aAAa,KAAK;AAExB,MAAI,cAAc,MAAM;AACtB,WAAO,EAAE,OAAO,aAAa,aAAa,OAAU;AAAA,EACtD;AAEA,QAAM,OAAO,WAAW;AAExB,UAAQ,MAAM;AAAA,IACZ,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,KAAK;AAAA,IACjD,KAAK;AACH,aAAO,EAAE,OAAO,aAAa,aAAa,MAAM;AAAA;AAAA;AAAA,IAIlD,KAAK;AACH,aAAO;AAAA,QACL,OAAO,YAAY;AAAA,UACjB,CAAC,SAAS,KAAK,SAAS,SAAS,WAAW;AAAA,QAC9C;AAAA,QACA,aAAa;AAAA,MACf;AAAA,IACF,SAAS;AACP,YAAM,kBAAkB;AACxB,YAAM,IAAI,MAAM,iCAAiC,eAAe,EAAE;AAAA,IACpE;AAAA,EACF;AACF;AAEA,SAAS,mBAAmB,UAA+B;AACzD,SAAO,SAAS,SAAS,KAAK,SAAS,SAAS,SAAS,CAAC,EAAE,SAAS;AACvE;;;AGlVO,SAAS,UAAU,WAAmB,QAAuB;AAClE,SAAO,OAAoC,OAAa,QAAkC,YAAoC;AAC5H,UAAM,MAAM,iDAAiD,SAAS,WAAW,KAAK;AACtF,UAAM,OAAO,KAAK,UAAU,MAAM;AAElC,UAAM,UAAU;AAAA,MACd,gBAAgB;AAAA,MAChB,eAAe,UAAU,MAAM;AAAA,IACjC;AAEA,UAAM,WAAW,MAAM,MAAM,KAAK;AAAA,MAChC,QAAQ;AAAA,MACR;AAAA,MACA;AAAA,IACF,CAAC;AAED,QAAI,SAAS,mBAAmB;AAC9B,aAAO;AAAA,IACT;AAEA,QAAK,OAAiC,WAAW,MAAM;AAErD,UAAI,SAAS,MAAM;AACjB,eAAO,SAAS;AAAA,MAClB;AACA,YAAM,IAAI,MAAM,2CAA2C;AAAA,IAC7D;AAGA,UAAM,OAAO,MAAM,SAAS,KAAyD;AACrF,WAAO,KAAK;AAAA,EACd;AACF;;;ACaO,SAAS,gBAAgB,SAAuC;AAGrE,MAAI;AAEJ,MAAI,QAAQ,SAAS;AACnB,cAAU,QAAQ;AAAA,EACpB,OAAO;AACL,UAAM,EAAE,WAAW,OAAO,IAAI;AAC9B,cAAU;AAAA,MACR,KAAK,UAAU,WAAW,MAAM;AAAA,IAClC;AAAA,EACF;AAEA,MAAI,CAAC,SAAS;AACZ,UAAM,IAAI;AAAA,MACR;AAAA,IACF;AAAA,EACF;AAKA,QAAM,kBAAkB,CACtB,SACA,WAAkC,CAAC,MAEnC,IAAI,2BAA2B,SAAS,UAAU;AAAA,IAChD,UAAU;AAAA,IACV;AAAA,IACA,SAAS,QAAQ;AAAA,EACnB,CAAC;AAEH,QAAM,WAAW,SACf,SACA,UACA;AACA,QAAI,YAAY;AACd,YAAM,IAAI;AAAA,QACR;AAAA,MACF;AAAA,IACF;AACA,WAAO,gBAAgB,SAAS,QAAQ;AAAA,EAC1C;AAEA,WAAS,OAAO;AAEhB,SAAO;AACT;","names":["UnsupportedFunctionalityError","UnsupportedFunctionalityError","response"]}
package/package.json CHANGED
@@ -2,7 +2,7 @@
2
2
  "name": "workers-ai-provider",
3
3
  "description": "Workers AI Provider for the vercel AI SDK",
4
4
  "type": "module",
5
- "version": "0.1.2",
5
+ "version": "0.2.0",
6
6
  "main": "dist/index.js",
7
7
  "types": "dist/index.d.ts",
8
8
  "repository": {
package/src/index.ts CHANGED
@@ -1,6 +1,37 @@
1
1
  import { WorkersAIChatLanguageModel } from "./workersai-chat-language-model";
2
2
  import type { WorkersAIChatSettings } from "./workersai-chat-settings";
3
3
  import type { TextGenerationModels } from "./workersai-models";
4
+ import { createRun } from "./utils";
5
+
6
+ export type WorkersAISettings =
7
+ ({
8
+ /**
9
+ * Provide a Cloudflare AI binding.
10
+ */
11
+ binding: Ai;
12
+
13
+ /**
14
+ * Credentials must be absent when a binding is given.
15
+ */
16
+ accountId?: never;
17
+ apiKey?: never;
18
+ }
19
+ | {
20
+ /**
21
+ * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
22
+ */
23
+ accountId: string;
24
+ apiKey: string;
25
+ /**
26
+ * Both binding must be absent if credentials are used directly.
27
+ */
28
+ binding?: never;
29
+ }) & {
30
+ /**
31
+ * Optionally specify a gateway.
32
+ */
33
+ gateway?: GatewayOptions;
34
+ };
4
35
 
5
36
  export interface WorkersAI {
6
37
  (
@@ -17,36 +48,40 @@ export interface WorkersAI {
17
48
  ): WorkersAIChatLanguageModel;
18
49
  }
19
50
 
20
- export interface WorkersAISettings {
21
- /**
22
- * Provide an `env.AI` binding to use for the AI inference.
23
- * You can set up an AI bindings in your Workers project
24
- * by adding the following this to `wrangler.toml`:
25
-
26
- ```toml
27
- [ai]
28
- binding = "AI"
29
- ```
30
- **/
31
- binding: Ai;
32
- /**
33
- * Optionally set Cloudflare AI Gateway options.
34
- */
35
- gateway?: GatewayOptions;
36
- }
37
-
38
51
  /**
39
52
  * Create a Workers AI provider instance.
40
- **/
53
+ */
41
54
  export function createWorkersAI(options: WorkersAISettings): WorkersAI {
55
+ // Use a binding if one is directly provided. Otherwise use credentials to create
56
+ // a `run` method that calls the Cloudflare REST API.
57
+ let binding: Ai | undefined;
58
+
59
+ if (options.binding) {
60
+ binding = options.binding;
61
+ } else {
62
+ const { accountId, apiKey } = options;
63
+ binding = {
64
+ run: createRun(accountId, apiKey),
65
+ } as Ai;
66
+ }
67
+
68
+ if (!binding) {
69
+ throw new Error(
70
+ "Either a binding or credentials must be provided."
71
+ );
72
+ }
73
+
74
+ /**
75
+ * Helper function to create a chat model instance.
76
+ */
42
77
  const createChatModel = (
43
78
  modelId: TextGenerationModels,
44
79
  settings: WorkersAIChatSettings = {}
45
80
  ) =>
46
81
  new WorkersAIChatLanguageModel(modelId, settings, {
47
82
  provider: "workersai.chat",
48
- binding: options.binding,
49
- gateway: options.gateway,
83
+ binding,
84
+ gateway: options.gateway
50
85
  });
51
86
 
52
87
  const provider = function (
@@ -58,7 +93,6 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
58
93
  "The WorkersAI model function cannot be called with the new keyword."
59
94
  );
60
95
  }
61
-
62
96
  return createChatModel(modelId, settings);
63
97
  };
64
98
 
@@ -66,3 +100,4 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
66
100
 
67
101
  return provider;
68
102
  }
103
+
package/src/utils.ts ADDED
@@ -0,0 +1,64 @@
1
+ /**
2
+ * Creates a run method that mimics the Cloudflare Workers AI binding,
3
+ * but uses the Cloudflare REST API under the hood.
4
+ *
5
+ * @param accountId - Your Cloudflare account identifier.
6
+ * @param apiKey - Your Cloudflare API token/key with appropriate permissions.
7
+ * @returns An function matching `Ai['run']`.
8
+ */
9
+ export function createRun(accountId: string, apiKey: string): AiRun {
10
+ return async <Name extends keyof AiModels>(model: Name, inputs: AiModels[Name]["inputs"], options?: AiOptions | undefined) => {
11
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
12
+ const body = JSON.stringify(inputs);
13
+
14
+ const headers = {
15
+ "Content-Type": "application/json",
16
+ Authorization: `Bearer ${apiKey}`,
17
+ };
18
+
19
+ const response = await fetch(url, {
20
+ method: "POST",
21
+ headers,
22
+ body,
23
+ }) as Response;
24
+
25
+ if (options?.returnRawResponse) {
26
+ return response;
27
+ }
28
+
29
+ if ((inputs as AiTextGenerationInput).stream === true) {
30
+ // If there's a stream, return the raw body so the caller can process it
31
+ if (response.body) {
32
+ return response.body;
33
+ }
34
+ throw new Error("No readable body available for streaming.");
35
+ }
36
+
37
+ // Otherwise, parse JSON and return the data.result
38
+ const data = await response.json<{ result: AiModels[Name]["postProcessedOutputs"] }>();
39
+ return data.result;
40
+ };
41
+ }
42
+
43
+ interface AiRun {
44
+ // (1) Return raw response if `options.returnRawResponse` is `true`.
45
+ <Name extends keyof AiModels>(
46
+ model: Name,
47
+ inputs: AiModels[Name]["inputs"],
48
+ options: AiOptions & { returnRawResponse: true }
49
+ ): Promise<Response>;
50
+
51
+ // (2) Return a stream if the input has `stream: true`.
52
+ <Name extends keyof AiModels>(
53
+ model: Name,
54
+ inputs: AiModels[Name]["inputs"] & { stream: true },
55
+ options?: AiOptions
56
+ ): Promise<ReadableStream<Uint8Array>>;
57
+
58
+ // (3) Return the post-processed outputs by default.
59
+ <Name extends keyof AiModels>(
60
+ model: Name,
61
+ inputs: AiModels[Name]["inputs"],
62
+ options?: AiOptions
63
+ ): Promise<AiModels[Name]["postProcessedOutputs"]>;
64
+ }
@@ -11,6 +11,7 @@ import type { TextGenerationModels } from "./workersai-models";
11
11
 
12
12
  import { events } from "fetch-event-stream";
13
13
  import { mapWorkersAIUsage } from "./map-workersai-usage";
14
+ import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";
14
15
 
15
16
  type WorkersAIChatConfig = {
16
17
  provider: string;
@@ -180,6 +181,47 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
180
181
  ): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
181
182
  const { args, warnings } = this.getArgs(options);
182
183
 
184
+ // [1] When the latest message is not a tool response, we use the regular generate function
185
+ // and simulate it as a streamed response in order to satisfy the AI SDK's interface for
186
+ // doStream...
187
+ if (args.tools?.length && lastMessageWasUser(args.messages)) {
188
+ const response = await this.doGenerate(options);
189
+
190
+ if ((response instanceof ReadableStream)) {
191
+ throw new Error("This shouldn't happen");
192
+ }
193
+
194
+ return {
195
+ stream: new ReadableStream<LanguageModelV1StreamPart>({
196
+ async start(controller) {
197
+ if (response.text) {
198
+ controller.enqueue({
199
+ type: "text-delta",
200
+ textDelta: response.text,
201
+ })
202
+ }
203
+ if (response.toolCalls) {
204
+ for (const toolCall of response.toolCalls) {
205
+ controller.enqueue({
206
+ type: "tool-call",
207
+ ...toolCall,
208
+ })
209
+ }
210
+ }
211
+ controller.enqueue({
212
+ type: "finish",
213
+ finishReason: "stop",
214
+ usage: response.usage,
215
+ });
216
+ controller.close();
217
+ },
218
+ }),
219
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
220
+ warnings,
221
+ };
222
+ }
223
+
224
+ // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model.
183
225
  const response = await this.config.binding.run(
184
226
  args.model,
185
227
  {
@@ -189,6 +231,8 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 {
189
231
  temperature: args.temperature,
190
232
  tools: args.tools,
191
233
  top_p: args.top_p,
234
+ // @ts-expect-error response_format not yet added to types
235
+ response_format: args.response_format,
192
236
  },
193
237
  { gateway: this.config.gateway ?? this.settings.gateway }
194
238
  );
@@ -297,3 +341,7 @@ function prepareToolsAndToolChoice(
297
341
  }
298
342
  }
299
343
  }
344
+
345
+ function lastMessageWasUser(messages: WorkersAIChatPrompt) {
346
+ return messages.length > 0 && messages[messages.length - 1].role === "user";
347
+ }