workers-ai-provider 0.1.3 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -22,8 +22,7 @@ binding = "AI"
22
22
  Then in your Worker, import the factory function and create a new AI provider:
23
23
 
24
24
  ```ts
25
- // index.ts
26
- import { createWorkersAI } from "workers-ai-provider";
25
+ import { createWorkersAI } from "workers-ai-provider";
27
26
  import { streamText } from "ai";
28
27
 
29
28
  type Env = {
@@ -58,6 +57,25 @@ export default {
58
57
  };
59
58
  ```
60
59
 
60
+ You can also create the provider from your Cloudflare account ID and API token instead of a binding, for example to use Workers AI outside of a Worker. Here is how you can use it in a Node script:
61
+
62
+ ```js
63
+ const workersai = createWorkersAI({
64
+ accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
65
+ apiKey: process.env.CLOUDFLARE_API_KEY
66
+ });
67
+
68
+ const result = await streamText({
69
+ model: workersai("@cf/meta/llama-2-7b-chat-int8"),
70
+ messages: [
71
+ {
72
+ role: "user",
73
+ content: "Write an essay about hello world",
74
+ },
75
+ ],
76
+ });
77
+ ```
78
+
61
79
  For more info, refer to the documentation of the [Vercel AI SDK](https://sdk.vercel.ai/).
62
80
 
63
81
  ### Credits
package/dist/index.d.ts CHANGED
@@ -1,11 +1,10 @@
1
- import { LanguageModelV1 } from '@ai-sdk/provider';
1
+ import { LanguageModelV1, ImageModelV1 } from '@ai-sdk/provider';
2
2
 
3
3
  interface WorkersAIChatSettings {
4
4
  /**
5
- Whether to inject a safety prompt before all conversations.
6
-
7
- Defaults to `false`.
8
- */
5
+ * Whether to inject a safety prompt before all conversations.
6
+ * Defaults to `false`.
7
+ */
9
8
  safePrompt?: boolean;
10
9
  /**
11
10
  * Optionally set Cloudflare AI Gateway options.
@@ -18,6 +17,7 @@ interface WorkersAIChatSettings {
18
17
  * The names of the BaseAiTextGeneration models.
19
18
  */
20
19
  type TextGenerationModels = Exclude<value2key<AiModels, BaseAiTextGeneration>, value2key<AiModels, BaseAiTextToImage>>;
20
+ type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
21
21
  type value2key<T, V> = {
22
22
  [K in keyof T]: T[K] extends V ? K : never;
23
23
  }[keyof T];
@@ -40,33 +40,67 @@ declare class WorkersAIChatLanguageModel implements LanguageModelV1 {
40
40
  doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
41
41
  }
42
42
 
43
/**
 * Configuration handed to an image model by `createWorkersAI`.
 */
type WorkersAIImageConfig = {
  /** Provider identifier reported by the model (e.g. `"workersai.image"`). */
  provider: string;
  /** The Workers AI binding, or a credentials-backed object exposing the same `run` method. */
  binding: Ai;
  /** Optional Cloudflare AI Gateway options. */
  gateway?: GatewayOptions;
};

/**
 * Per-model settings for image generation.
 */
type WorkersAIImageSettings = {
  /**
   * Maximum number of images generated per `doGenerate` call. Defaults to `1`.
   */
  maxImagesPerCall?: number;
};

declare class WorkersAIImageModel implements ImageModelV1 {
  readonly modelId: ImageGenerationModels;
  readonly settings: WorkersAIImageSettings;
  readonly config: WorkersAIImageConfig;
  readonly specificationVersion = "v1";
  get maxImagesPerCall(): number;
  get provider(): string;
  constructor(modelId: ImageGenerationModels, settings: WorkersAIImageSettings, config: WorkersAIImageConfig);
  doGenerate({ prompt, n, size, aspectRatio, seed, }: Parameters<ImageModelV1["doGenerate"]>[0]): Promise<Awaited<ReturnType<ImageModelV1["doGenerate"]>>>;
}

/**
 * Options for `createWorkersAI`: supply EITHER a Workers AI binding OR Cloudflare
 * API credentials (never both), optionally combined with AI Gateway options.
 */
type WorkersAISettings = ({
  /**
   * A Workers AI binding (`env.AI`). Set one up in your Workers project by adding
   * the following to `wrangler.toml`:
   *
   * ```toml
   * [ai]
   * binding = "AI"
   * ```
   */
  binding: Ai;
  /** Must be omitted when a binding is given. */
  accountId?: never;
  /** Must be omitted when a binding is given. */
  apiKey?: never;
} | {
  /**
   * Cloudflare account ID. Used together with `apiKey` to call Workers AI over the
   * Cloudflare REST API when no binding is available (e.g. outside a Worker).
   */
  accountId: string;
  /**
   * Cloudflare API token, sent as `Authorization: Bearer <apiKey>`.
   */
  apiKey: string;
  /** Must be omitted when credentials are given. */
  binding?: never;
}) & {
  /**
   * Optional Cloudflare AI Gateway options.
   */
  gateway?: GatewayOptions;
};
43
90
  interface WorkersAI {
44
91
  (modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
45
92
  /**
46
93
  * Creates a model for text generation.
47
94
  **/
48
95
  chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
49
- }
50
- interface WorkersAISettings {
51
96
  /**
52
- * Provide an `env.AI` binding to use for the AI inference.
53
- * You can set up an AI bindings in your Workers project
54
- * by adding the following this to `wrangler.toml`:
55
-
56
- ```toml
57
- [ai]
58
- binding = "AI"
59
- ```
97
+ * Creates a model for image generation.
60
98
  **/
61
- binding: Ai;
62
- /**
63
- * Optionally set Cloudflare AI Gateway options.
64
- */
65
- gateway?: GatewayOptions;
99
+ image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
66
100
  }
67
101
  /**
68
102
  * Create a Workers AI provider instance.
69
- **/
103
+ */
70
104
  declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
71
105
 
72
106
  export { type WorkersAI, type WorkersAISettings, createWorkersAI };
package/dist/index.js CHANGED
@@ -1,17 +1,161 @@
1
1
  var __defProp = Object.defineProperty;
2
+ var __typeError = (msg) => {
3
+ throw TypeError(msg);
4
+ };
2
5
  var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value;
3
6
  var __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== "symbol" ? key + "" : key, value);
7
+ var __accessCheck = (obj, member, msg) => member.has(obj) || __typeError("Cannot " + msg);
8
+ var __privateGet = (obj, member, getter) => (__accessCheck(obj, member, "read from private field"), getter ? getter.call(obj) : member.get(obj));
9
+ var __privateAdd = (obj, member, value) => member.has(obj) ? __typeError("Cannot add the same private member more than once") : member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
10
+ var __privateSet = (obj, member, value, setter) => (__accessCheck(obj, member, "write to private field"), setter ? setter.call(obj, value) : member.set(obj, value), value);
4
11
 
5
- // src/workersai-chat-language-model.ts
6
- import {
7
- UnsupportedFunctionalityError as UnsupportedFunctionalityError2
8
- } from "@ai-sdk/provider";
9
- import { z } from "zod";
12
+ // src/utils.ts
13
+ function createRun(config) {
14
+ const { accountId, apiKey } = config;
15
+ return async function run(model, inputs, options) {
16
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
17
+ const headers = {
18
+ "Content-Type": "application/json",
19
+ Authorization: `Bearer ${apiKey}`
20
+ };
21
+ const body = JSON.stringify(inputs);
22
+ const response = await fetch(url, {
23
+ method: "POST",
24
+ headers,
25
+ body
26
+ });
27
+ if (options?.returnRawResponse) {
28
+ return response;
29
+ }
30
+ if (inputs.stream === true) {
31
+ if (response.body) {
32
+ return response.body;
33
+ }
34
+ throw new Error("No readable body available for streaming.");
35
+ }
36
+ const data = await response.json();
37
+ return data.result;
38
+ };
39
+ }
40
+
41
+ // ../../node_modules/@ai-sdk/provider/dist/index.mjs
42
+ var marker = "vercel.ai.error";
43
+ var symbol = Symbol.for(marker);
44
+ var _a;
45
+ var _AISDKError = class _AISDKError2 extends Error {
46
+ /**
47
+ * Creates an AI SDK Error.
48
+ *
49
+ * @param {Object} params - The parameters for creating the error.
50
+ * @param {string} params.name - The name of the error.
51
+ * @param {string} params.message - The error message.
52
+ * @param {unknown} [params.cause] - The underlying cause of the error.
53
+ */
54
+ constructor({
55
+ name: name14,
56
+ message,
57
+ cause
58
+ }) {
59
+ super(message);
60
+ this[_a] = true;
61
+ this.name = name14;
62
+ this.cause = cause;
63
+ }
64
+ /**
65
+ * Checks if the given error is an AI SDK Error.
66
+ * @param {unknown} error - The error to check.
67
+ * @returns {boolean} True if the error is an AI SDK Error, false otherwise.
68
+ */
69
+ static isInstance(error) {
70
+ return _AISDKError2.hasMarker(error, marker);
71
+ }
72
+ static hasMarker(error, marker15) {
73
+ const markerSymbol = Symbol.for(marker15);
74
+ return error != null && typeof error === "object" && markerSymbol in error && typeof error[markerSymbol] === "boolean" && error[markerSymbol] === true;
75
+ }
76
+ };
77
+ _a = symbol;
78
+ var AISDKError = _AISDKError;
79
+ var name = "AI_APICallError";
80
+ var marker2 = `vercel.ai.error.${name}`;
81
+ var symbol2 = Symbol.for(marker2);
82
+ var _a2;
83
+ _a2 = symbol2;
84
+ var name2 = "AI_EmptyResponseBodyError";
85
+ var marker3 = `vercel.ai.error.${name2}`;
86
+ var symbol3 = Symbol.for(marker3);
87
+ var _a3;
88
+ _a3 = symbol3;
89
+ var name3 = "AI_InvalidArgumentError";
90
+ var marker4 = `vercel.ai.error.${name3}`;
91
+ var symbol4 = Symbol.for(marker4);
92
+ var _a4;
93
+ _a4 = symbol4;
94
+ var name4 = "AI_InvalidPromptError";
95
+ var marker5 = `vercel.ai.error.${name4}`;
96
+ var symbol5 = Symbol.for(marker5);
97
+ var _a5;
98
+ _a5 = symbol5;
99
+ var name5 = "AI_InvalidResponseDataError";
100
+ var marker6 = `vercel.ai.error.${name5}`;
101
+ var symbol6 = Symbol.for(marker6);
102
+ var _a6;
103
+ _a6 = symbol6;
104
+ var name6 = "AI_JSONParseError";
105
+ var marker7 = `vercel.ai.error.${name6}`;
106
+ var symbol7 = Symbol.for(marker7);
107
+ var _a7;
108
+ _a7 = symbol7;
109
+ var name7 = "AI_LoadAPIKeyError";
110
+ var marker8 = `vercel.ai.error.${name7}`;
111
+ var symbol8 = Symbol.for(marker8);
112
+ var _a8;
113
+ _a8 = symbol8;
114
+ var name8 = "AI_LoadSettingError";
115
+ var marker9 = `vercel.ai.error.${name8}`;
116
+ var symbol9 = Symbol.for(marker9);
117
+ var _a9;
118
+ _a9 = symbol9;
119
+ var name9 = "AI_NoContentGeneratedError";
120
+ var marker10 = `vercel.ai.error.${name9}`;
121
+ var symbol10 = Symbol.for(marker10);
122
+ var _a10;
123
+ _a10 = symbol10;
124
+ var name10 = "AI_NoSuchModelError";
125
+ var marker11 = `vercel.ai.error.${name10}`;
126
+ var symbol11 = Symbol.for(marker11);
127
+ var _a11;
128
+ _a11 = symbol11;
129
+ var name11 = "AI_TooManyEmbeddingValuesForCallError";
130
+ var marker12 = `vercel.ai.error.${name11}`;
131
+ var symbol12 = Symbol.for(marker12);
132
+ var _a12;
133
+ _a12 = symbol12;
134
+ var name12 = "AI_TypeValidationError";
135
+ var marker13 = `vercel.ai.error.${name12}`;
136
+ var symbol13 = Symbol.for(marker13);
137
+ var _a13;
138
+ _a13 = symbol13;
139
+ var name13 = "AI_UnsupportedFunctionalityError";
140
+ var marker14 = `vercel.ai.error.${name13}`;
141
+ var symbol14 = Symbol.for(marker14);
142
+ var _a14;
143
+ var UnsupportedFunctionalityError = class extends AISDKError {
144
+ constructor({
145
+ functionality,
146
+ message = `'${functionality}' functionality not supported.`
147
+ }) {
148
+ super({ name: name13, message });
149
+ this[_a14] = true;
150
+ this.functionality = functionality;
151
+ }
152
+ static isInstance(error) {
153
+ return AISDKError.hasMarker(error, marker14);
154
+ }
155
+ };
156
+ _a14 = symbol14;
10
157
 
11
158
  // src/convert-to-workersai-chat-messages.ts
12
- import {
13
- UnsupportedFunctionalityError
14
- } from "@ai-sdk/provider";
15
159
  function convertToWorkersAIChatMessages(prompt) {
16
160
  const messages = [];
17
161
  for (const { role, content } of prompt) {
@@ -48,7 +192,10 @@ function convertToWorkersAIChatMessages(prompt) {
48
192
  break;
49
193
  }
50
194
  case "tool-call": {
51
- text = JSON.stringify({ name: part.toolName, parameters: part.args });
195
+ text = JSON.stringify({
196
+ name: part.toolName,
197
+ parameters: part.args
198
+ });
52
199
  toolCalls.push({
53
200
  id: part.toolCallId,
54
201
  type: "function",
@@ -68,10 +215,10 @@ function convertToWorkersAIChatMessages(prompt) {
68
215
  messages.push({
69
216
  role: "assistant",
70
217
  content: text,
71
- tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name, arguments: args } }) => ({
218
+ tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name: name14, arguments: args } }) => ({
72
219
  id: "null",
73
220
  type: "function",
74
- function: { name, arguments: args }
221
+ function: { name: name14, arguments: args }
75
222
  })) : void 0
76
223
  });
77
224
  break;
@@ -95,8 +242,98 @@ function convertToWorkersAIChatMessages(prompt) {
95
242
  return messages;
96
243
  }
97
244
 
98
- // src/workersai-chat-language-model.ts
99
- import { events } from "fetch-event-stream";
245
// ../../node_modules/fetch-event-stream/esm/deps/jsr.io/@std/streams/0.221.0/text_line_stream.js
var _currentLine;
/**
 * TransformStream splitting incoming text chunks into lines. `\n` and `\r\n` always end a
 * line; with `allowCR: true` a lone `\r` does too. A trailing partial line is buffered
 * between chunks and emitted on flush.
 */
var TextLineStream = class extends TransformStream {
  /** Constructs a new instance. */
  constructor(options = { allowCR: false }) {
    super({
      transform: (chars, controller) => {
        let pending = __privateGet(this, _currentLine) + chars;
        for (;;) {
          const lf = pending.indexOf("\n");
          const cr = options.allowCR ? pending.indexOf("\r") : -1;
          // A CR that is not the last buffered char (it could precede an LF from the next
          // chunk) and comes before the next LF without directly preceding it ends a line.
          const loneCr = cr !== -1 && cr !== pending.length - 1 && (lf === -1 || lf - 1 > cr);
          if (loneCr) {
            controller.enqueue(pending.slice(0, cr));
            pending = pending.slice(cr + 1);
            continue;
          }
          if (lf === -1) {
            break;
          }
          const end = pending[lf - 1] === "\r" ? lf - 1 : lf;
          controller.enqueue(pending.slice(0, end));
          pending = pending.slice(lf + 1);
        }
        __privateSet(this, _currentLine, pending);
      },
      flush: (controller) => {
        const rest = __privateGet(this, _currentLine);
        if (rest === "") {
          return;
        }
        controller.enqueue(options.allowCR && rest.endsWith("\r") ? rest.slice(0, -1) : rest);
      }
    });
    __privateAdd(this, _currentLine, "");
  }
};
_currentLine = new WeakMap();

// ../../node_modules/fetch-event-stream/esm/utils.js
/** Decode a byte stream as text and split it into lines (LF, CRLF, or lone CR). */
function stream(input) {
  const decoder = new TextDecoderStream();
  const lineSplitter = new TextLineStream({ allowCR: true });
  return input.pipeThrough(decoder).pipeThrough(lineSplitter);
}
/**
 * Split a line into `[field, value]` at its first colon, dropping whitespace after it.
 * Returns undefined when there is no colon, or when the colon is the first character.
 */
function split(input) {
  const match = /[:]\s*/.exec(input);
  const idx = match && match.index;
  if (!idx) {
    return void 0;
  }
  return [input.slice(0, idx), input.slice(idx + match[0].length)];
}

// ../../node_modules/fetch-event-stream/esm/mod.js
/**
 * Async generator over the server-sent events in a fetch `Response` body. A blank line
 * yields the accumulated event; `data` lines are joined with "\n"; `id` and `retry`
 * are coerced to numbers when possible. Stops (cancelling the reader) once `signal`
 * is aborted.
 */
async function* events(res, signal) {
  if (!res.body) {
    return;
  }
  const reader = stream(res.body).getReader();
  let event;
  while (true) {
    if (signal && signal.aborted) {
      return reader.cancel();
    }
    const { done, value } = await reader.read();
    if (done) {
      return;
    }
    if (!value) {
      if (event) {
        yield event;
      }
      event = void 0;
      continue;
    }
    const [field, fieldValue] = split(value) || [];
    if (!field) {
      continue;
    }
    switch (field) {
      case "data":
        if (!event) event = {};
        event.data = event.data ? event.data + "\n" + fieldValue : fieldValue;
        break;
      case "event":
        if (!event) event = {};
        event.event = fieldValue;
        break;
      case "id":
        if (!event) event = {};
        event.id = +fieldValue || fieldValue;
        break;
      case "retry":
        if (!event) event = {};
        event.retry = +fieldValue || void 0;
        break;
    }
  }
}
100
337
 
101
338
  // src/map-workersai-usage.ts
102
339
  function mapWorkersAIUsage(output) {
@@ -195,7 +432,7 @@ var WorkersAIChatLanguageModel = class {
195
432
  // @ts-expect-error - this is unreachable code
196
433
  // TODO: fixme
197
434
  case "object-grammar": {
198
- throw new UnsupportedFunctionalityError2({
435
+ throw new UnsupportedFunctionalityError({
199
436
  functionality: "object-grammar mode"
200
437
  });
201
438
  }
@@ -325,10 +562,6 @@ var WorkersAIChatLanguageModel = class {
325
562
  };
326
563
  }
327
564
  };
328
- var workersAIChatResponseSchema = z.object({
329
- response: z.string()
330
- });
331
- var workersAIChatChunkSchema = z.instanceof(Uint8Array);
332
565
  function prepareToolsAndToolChoice(mode) {
333
566
  const tools = mode.tools?.length ? mode.tools : void 0;
334
567
  if (tools == null) {
@@ -360,9 +593,7 @@ function prepareToolsAndToolChoice(mode) {
360
593
  // so we filter the tools and force the tool choice through 'any'
361
594
  case "tool":
362
595
  return {
363
- tools: mappedTools.filter(
364
- (tool) => tool.function.name === toolChoice.toolName
365
- ),
596
+ tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
366
597
  tool_choice: "any"
367
598
  };
368
599
  default: {
@@ -375,22 +606,128 @@ function lastMessageWasUser(messages) {
375
606
  return messages.length > 0 && messages[messages.length - 1].role === "user";
376
607
  }
377
608
 
609
+ // src/workersai-image-model.ts
610
+ var WorkersAIImageModel = class {
611
+ constructor(modelId, settings, config) {
612
+ this.modelId = modelId;
613
+ this.settings = settings;
614
+ this.config = config;
615
+ __publicField(this, "specificationVersion", "v1");
616
+ }
617
+ get maxImagesPerCall() {
618
+ return this.settings.maxImagesPerCall ?? 1;
619
+ }
620
+ get provider() {
621
+ return this.config.provider;
622
+ }
623
+ async doGenerate({
624
+ prompt,
625
+ n,
626
+ size,
627
+ aspectRatio,
628
+ seed
629
+ // headers,
630
+ // abortSignal,
631
+ }) {
632
+ const { width, height } = getDimensionsFromSizeString(size);
633
+ const warnings = [];
634
+ if (aspectRatio != null) {
635
+ warnings.push({
636
+ type: "unsupported-setting",
637
+ setting: "aspectRatio",
638
+ details: "This model does not support aspect ratio. Use `size` instead."
639
+ });
640
+ }
641
+ const generateImage = async () => {
642
+ const outputStream = await this.config.binding.run(
643
+ this.modelId,
644
+ {
645
+ prompt,
646
+ seed,
647
+ width,
648
+ height
649
+ }
650
+ );
651
+ return streamToUint8Array(outputStream);
652
+ };
653
+ const images = await Promise.all(
654
+ Array.from({ length: n }, () => generateImage())
655
+ );
656
+ return {
657
+ images,
658
+ warnings,
659
+ response: {
660
+ timestamp: /* @__PURE__ */ new Date(),
661
+ modelId: this.modelId,
662
+ headers: {}
663
+ }
664
+ };
665
+ }
666
+ };
667
+ function getDimensionsFromSizeString(size) {
668
+ const [width, height] = size?.split("x") ?? [void 0, void 0];
669
+ return {
670
+ width: parseInteger(width),
671
+ height: parseInteger(height)
672
+ };
673
+ }
674
+ function parseInteger(value) {
675
+ if (value === "" || !value) return void 0;
676
+ const number = Number(value);
677
+ return Number.isInteger(number) ? number : void 0;
678
+ }
679
+ async function streamToUint8Array(stream2) {
680
+ const reader = stream2.getReader();
681
+ const chunks = [];
682
+ let totalLength = 0;
683
+ while (true) {
684
+ const { done, value } = await reader.read();
685
+ if (done) break;
686
+ chunks.push(value);
687
+ totalLength += value.length;
688
+ }
689
+ const result = new Uint8Array(totalLength);
690
+ let offset = 0;
691
+ for (const chunk of chunks) {
692
+ result.set(chunk, offset);
693
+ offset += chunk.length;
694
+ }
695
+ return result;
696
+ }
697
+
378
698
  // src/index.ts
379
699
  function createWorkersAI(options) {
700
+ let binding;
701
+ if (options.binding) {
702
+ binding = options.binding;
703
+ } else {
704
+ const { accountId, apiKey } = options;
705
+ binding = {
706
+ run: createRun({ accountId, apiKey })
707
+ };
708
+ }
709
+ if (!binding) {
710
+ throw new Error("Either a binding or credentials must be provided.");
711
+ }
380
712
  const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
381
713
  provider: "workersai.chat",
382
- binding: options.binding,
714
+ binding,
715
+ gateway: options.gateway
716
+ });
717
+ const createImageModel = (modelId, settings = {}) => new WorkersAIImageModel(modelId, settings, {
718
+ provider: "workersai.image",
719
+ binding,
383
720
  gateway: options.gateway
384
721
  });
385
- const provider = function(modelId, settings) {
722
+ const provider = (modelId, settings) => {
386
723
  if (new.target) {
387
- throw new Error(
388
- "The WorkersAI model function cannot be called with the new keyword."
389
- );
724
+ throw new Error("The WorkersAI model function cannot be called with the new keyword.");
390
725
  }
391
726
  return createChatModel(modelId, settings);
392
727
  };
393
728
  provider.chat = createChatModel;
729
+ provider.image = createImageModel;
730
+ provider.imageModel = createImageModel;
394
731
  return provider;
395
732
  }
396
733
  export {