workers-ai-provider 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -22,8 +22,7 @@ binding = "AI"
22
22
  Then in your Worker, import the factory function and create a new AI provider:
23
23
 
24
24
  ```ts
25
- // index.ts
26
- import { createWorkersAI } from "workers-ai-provider";
25
+ import { createWorkersAI } from "../../../packages/workers-ai-provider/src";
27
26
  import { streamText } from "ai";
28
27
 
29
28
  type Env = {
package/dist/index.d.ts CHANGED
@@ -1,11 +1,10 @@
1
- import { LanguageModelV1 } from '@ai-sdk/provider';
1
+ import { LanguageModelV1, ImageModelV1 } from '@ai-sdk/provider';
2
2
 
3
3
  interface WorkersAIChatSettings {
4
4
  /**
5
- Whether to inject a safety prompt before all conversations.
6
-
7
- Defaults to `false`.
8
- */
5
+ * Whether to inject a safety prompt before all conversations.
6
+ * Defaults to `false`.
7
+ */
9
8
  safePrompt?: boolean;
10
9
  /**
11
10
  * Optionally set Cloudflare AI Gateway options.
@@ -18,6 +17,7 @@ interface WorkersAIChatSettings {
18
17
  * The names of the BaseAiTextGeneration models.
19
18
  */
20
19
  type TextGenerationModels = Exclude<value2key<AiModels, BaseAiTextGeneration>, value2key<AiModels, BaseAiTextToImage>>;
20
+ type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
21
21
  type value2key<T, V> = {
22
22
  [K in keyof T]: T[K] extends V ? K : never;
23
23
  }[keyof T];
@@ -40,6 +40,27 @@ declare class WorkersAIChatLanguageModel implements LanguageModelV1 {
40
40
  doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
41
41
  }
42
42
 
43
+ type WorkersAIImageConfig = {
44
+ provider: string;
45
+ binding: Ai;
46
+ gateway?: GatewayOptions;
47
+ };
48
+
49
+ type WorkersAIImageSettings = {
50
+ maxImagesPerCall?: number;
51
+ };
52
+
53
+ declare class WorkersAIImageModel implements ImageModelV1 {
54
+ readonly modelId: ImageGenerationModels;
55
+ readonly settings: WorkersAIImageSettings;
56
+ readonly config: WorkersAIImageConfig;
57
+ readonly specificationVersion = "v1";
58
+ get maxImagesPerCall(): number;
59
+ get provider(): string;
60
+ constructor(modelId: ImageGenerationModels, settings: WorkersAIImageSettings, config: WorkersAIImageConfig);
61
+ doGenerate({ prompt, n, size, aspectRatio, seed, }: Parameters<ImageModelV1["doGenerate"]>[0]): Promise<Awaited<ReturnType<ImageModelV1["doGenerate"]>>>;
62
+ }
63
+
43
64
  type WorkersAISettings = ({
44
65
  /**
45
66
  * Provide a Cloudflare AI binding.
@@ -72,6 +93,10 @@ interface WorkersAI {
72
93
  * Creates a model for text generation.
73
94
  **/
74
95
  chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
96
+ /**
97
+ * Creates a model for image generation.
98
+ **/
99
+ image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
75
100
  }
76
101
  /**
77
102
  * Create a Workers AI provider instance.
package/dist/index.js CHANGED
@@ -1,17 +1,161 @@
1
1
  var __defProp = Object.defineProperty;
2
+ var __typeError = (msg) => {
3
+ throw TypeError(msg);
4
+ };
2
5
  var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value;
3
6
  var __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== "symbol" ? key + "" : key, value);
7
+ var __accessCheck = (obj, member, msg) => member.has(obj) || __typeError("Cannot " + msg);
8
+ var __privateGet = (obj, member, getter) => (__accessCheck(obj, member, "read from private field"), getter ? getter.call(obj) : member.get(obj));
9
+ var __privateAdd = (obj, member, value) => member.has(obj) ? __typeError("Cannot add the same private member more than once") : member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
10
+ var __privateSet = (obj, member, value, setter) => (__accessCheck(obj, member, "write to private field"), setter ? setter.call(obj, value) : member.set(obj, value), value);
4
11
 
5
- // src/workersai-chat-language-model.ts
6
- import {
7
- UnsupportedFunctionalityError as UnsupportedFunctionalityError2
8
- } from "@ai-sdk/provider";
9
- import { z } from "zod";
12
+ // src/utils.ts
13
+ function createRun(config) {
14
+ const { accountId, apiKey } = config;
15
+ return async function run(model, inputs, options) {
16
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
17
+ const headers = {
18
+ "Content-Type": "application/json",
19
+ Authorization: `Bearer ${apiKey}`
20
+ };
21
+ const body = JSON.stringify(inputs);
22
+ const response = await fetch(url, {
23
+ method: "POST",
24
+ headers,
25
+ body
26
+ });
27
+ if (options?.returnRawResponse) {
28
+ return response;
29
+ }
30
+ if (inputs.stream === true) {
31
+ if (response.body) {
32
+ return response.body;
33
+ }
34
+ throw new Error("No readable body available for streaming.");
35
+ }
36
+ const data = await response.json();
37
+ return data.result;
38
+ };
39
+ }
40
+
41
+ // ../../node_modules/@ai-sdk/provider/dist/index.mjs
42
+ var marker = "vercel.ai.error";
43
+ var symbol = Symbol.for(marker);
44
+ var _a;
45
+ var _AISDKError = class _AISDKError2 extends Error {
46
+ /**
47
+ * Creates an AI SDK Error.
48
+ *
49
+ * @param {Object} params - The parameters for creating the error.
50
+ * @param {string} params.name - The name of the error.
51
+ * @param {string} params.message - The error message.
52
+ * @param {unknown} [params.cause] - The underlying cause of the error.
53
+ */
54
+ constructor({
55
+ name: name14,
56
+ message,
57
+ cause
58
+ }) {
59
+ super(message);
60
+ this[_a] = true;
61
+ this.name = name14;
62
+ this.cause = cause;
63
+ }
64
+ /**
65
+ * Checks if the given error is an AI SDK Error.
66
+ * @param {unknown} error - The error to check.
67
+ * @returns {boolean} True if the error is an AI SDK Error, false otherwise.
68
+ */
69
+ static isInstance(error) {
70
+ return _AISDKError2.hasMarker(error, marker);
71
+ }
72
+ static hasMarker(error, marker15) {
73
+ const markerSymbol = Symbol.for(marker15);
74
+ return error != null && typeof error === "object" && markerSymbol in error && typeof error[markerSymbol] === "boolean" && error[markerSymbol] === true;
75
+ }
76
+ };
77
+ _a = symbol;
78
+ var AISDKError = _AISDKError;
79
+ var name = "AI_APICallError";
80
+ var marker2 = `vercel.ai.error.${name}`;
81
+ var symbol2 = Symbol.for(marker2);
82
+ var _a2;
83
+ _a2 = symbol2;
84
+ var name2 = "AI_EmptyResponseBodyError";
85
+ var marker3 = `vercel.ai.error.${name2}`;
86
+ var symbol3 = Symbol.for(marker3);
87
+ var _a3;
88
+ _a3 = symbol3;
89
+ var name3 = "AI_InvalidArgumentError";
90
+ var marker4 = `vercel.ai.error.${name3}`;
91
+ var symbol4 = Symbol.for(marker4);
92
+ var _a4;
93
+ _a4 = symbol4;
94
+ var name4 = "AI_InvalidPromptError";
95
+ var marker5 = `vercel.ai.error.${name4}`;
96
+ var symbol5 = Symbol.for(marker5);
97
+ var _a5;
98
+ _a5 = symbol5;
99
+ var name5 = "AI_InvalidResponseDataError";
100
+ var marker6 = `vercel.ai.error.${name5}`;
101
+ var symbol6 = Symbol.for(marker6);
102
+ var _a6;
103
+ _a6 = symbol6;
104
+ var name6 = "AI_JSONParseError";
105
+ var marker7 = `vercel.ai.error.${name6}`;
106
+ var symbol7 = Symbol.for(marker7);
107
+ var _a7;
108
+ _a7 = symbol7;
109
+ var name7 = "AI_LoadAPIKeyError";
110
+ var marker8 = `vercel.ai.error.${name7}`;
111
+ var symbol8 = Symbol.for(marker8);
112
+ var _a8;
113
+ _a8 = symbol8;
114
+ var name8 = "AI_LoadSettingError";
115
+ var marker9 = `vercel.ai.error.${name8}`;
116
+ var symbol9 = Symbol.for(marker9);
117
+ var _a9;
118
+ _a9 = symbol9;
119
+ var name9 = "AI_NoContentGeneratedError";
120
+ var marker10 = `vercel.ai.error.${name9}`;
121
+ var symbol10 = Symbol.for(marker10);
122
+ var _a10;
123
+ _a10 = symbol10;
124
+ var name10 = "AI_NoSuchModelError";
125
+ var marker11 = `vercel.ai.error.${name10}`;
126
+ var symbol11 = Symbol.for(marker11);
127
+ var _a11;
128
+ _a11 = symbol11;
129
+ var name11 = "AI_TooManyEmbeddingValuesForCallError";
130
+ var marker12 = `vercel.ai.error.${name11}`;
131
+ var symbol12 = Symbol.for(marker12);
132
+ var _a12;
133
+ _a12 = symbol12;
134
+ var name12 = "AI_TypeValidationError";
135
+ var marker13 = `vercel.ai.error.${name12}`;
136
+ var symbol13 = Symbol.for(marker13);
137
+ var _a13;
138
+ _a13 = symbol13;
139
+ var name13 = "AI_UnsupportedFunctionalityError";
140
+ var marker14 = `vercel.ai.error.${name13}`;
141
+ var symbol14 = Symbol.for(marker14);
142
+ var _a14;
143
+ var UnsupportedFunctionalityError = class extends AISDKError {
144
+ constructor({
145
+ functionality,
146
+ message = `'${functionality}' functionality not supported.`
147
+ }) {
148
+ super({ name: name13, message });
149
+ this[_a14] = true;
150
+ this.functionality = functionality;
151
+ }
152
+ static isInstance(error) {
153
+ return AISDKError.hasMarker(error, marker14);
154
+ }
155
+ };
156
+ _a14 = symbol14;
10
157
 
11
158
  // src/convert-to-workersai-chat-messages.ts
12
- import {
13
- UnsupportedFunctionalityError
14
- } from "@ai-sdk/provider";
15
159
  function convertToWorkersAIChatMessages(prompt) {
16
160
  const messages = [];
17
161
  for (const { role, content } of prompt) {
@@ -48,7 +192,10 @@ function convertToWorkersAIChatMessages(prompt) {
48
192
  break;
49
193
  }
50
194
  case "tool-call": {
51
- text = JSON.stringify({ name: part.toolName, parameters: part.args });
195
+ text = JSON.stringify({
196
+ name: part.toolName,
197
+ parameters: part.args
198
+ });
52
199
  toolCalls.push({
53
200
  id: part.toolCallId,
54
201
  type: "function",
@@ -68,10 +215,10 @@ function convertToWorkersAIChatMessages(prompt) {
68
215
  messages.push({
69
216
  role: "assistant",
70
217
  content: text,
71
- tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name, arguments: args } }) => ({
218
+ tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name: name14, arguments: args } }) => ({
72
219
  id: "null",
73
220
  type: "function",
74
- function: { name, arguments: args }
221
+ function: { name: name14, arguments: args }
75
222
  })) : void 0
76
223
  });
77
224
  break;
@@ -95,8 +242,98 @@ function convertToWorkersAIChatMessages(prompt) {
95
242
  return messages;
96
243
  }
97
244
 
98
- // src/workersai-chat-language-model.ts
99
- import { events } from "fetch-event-stream";
245
+ // ../../node_modules/fetch-event-stream/esm/deps/jsr.io/@std/streams/0.221.0/text_line_stream.js
246
+ var _currentLine;
247
+ var TextLineStream = class extends TransformStream {
248
+ /** Constructs a new instance. */
249
+ constructor(options = { allowCR: false }) {
250
+ super({
251
+ transform: (chars, controller) => {
252
+ chars = __privateGet(this, _currentLine) + chars;
253
+ while (true) {
254
+ const lfIndex = chars.indexOf("\n");
255
+ const crIndex = options.allowCR ? chars.indexOf("\r") : -1;
256
+ if (crIndex !== -1 && crIndex !== chars.length - 1 && (lfIndex === -1 || lfIndex - 1 > crIndex)) {
257
+ controller.enqueue(chars.slice(0, crIndex));
258
+ chars = chars.slice(crIndex + 1);
259
+ continue;
260
+ }
261
+ if (lfIndex === -1)
262
+ break;
263
+ const endIndex = chars[lfIndex - 1] === "\r" ? lfIndex - 1 : lfIndex;
264
+ controller.enqueue(chars.slice(0, endIndex));
265
+ chars = chars.slice(lfIndex + 1);
266
+ }
267
+ __privateSet(this, _currentLine, chars);
268
+ },
269
+ flush: (controller) => {
270
+ if (__privateGet(this, _currentLine) === "")
271
+ return;
272
+ const currentLine = options.allowCR && __privateGet(this, _currentLine).endsWith("\r") ? __privateGet(this, _currentLine).slice(0, -1) : __privateGet(this, _currentLine);
273
+ controller.enqueue(currentLine);
274
+ }
275
+ });
276
+ __privateAdd(this, _currentLine, "");
277
+ }
278
+ };
279
+ _currentLine = new WeakMap();
280
+
281
+ // ../../node_modules/fetch-event-stream/esm/utils.js
282
+ function stream(input) {
283
+ let decoder = new TextDecoderStream();
284
+ let split2 = new TextLineStream({ allowCR: true });
285
+ return input.pipeThrough(decoder).pipeThrough(split2);
286
+ }
287
+ function split(input) {
288
+ let rgx = /[:]\s*/;
289
+ let match = rgx.exec(input);
290
+ let idx = match && match.index;
291
+ if (idx) {
292
+ return [
293
+ input.substring(0, idx),
294
+ input.substring(idx + match[0].length)
295
+ ];
296
+ }
297
+ }
298
+
299
+ // ../../node_modules/fetch-event-stream/esm/mod.js
300
+ async function* events(res, signal) {
301
+ if (!res.body)
302
+ return;
303
+ let iter = stream(res.body);
304
+ let line, reader = iter.getReader();
305
+ let event;
306
+ for (; ; ) {
307
+ if (signal && signal.aborted) {
308
+ return reader.cancel();
309
+ }
310
+ line = await reader.read();
311
+ if (line.done)
312
+ return;
313
+ if (!line.value) {
314
+ if (event)
315
+ yield event;
316
+ event = void 0;
317
+ continue;
318
+ }
319
+ let [field, value] = split(line.value) || [];
320
+ if (!field)
321
+ continue;
322
+ if (field === "data") {
323
+ event || (event = {});
324
+ event[field] = event[field] ? event[field] + "\n" + value : value;
325
+ } else if (field === "event") {
326
+ event || (event = {});
327
+ event[field] = value;
328
+ } else if (field === "id") {
329
+ event || (event = {});
330
+ event[field] = +value || value;
331
+ } else if (field === "retry") {
332
+ event || (event = {});
333
+ event[field] = +value || void 0;
334
+ }
335
+ }
336
+ }
100
337
 
101
338
  // src/map-workersai-usage.ts
102
339
  function mapWorkersAIUsage(output) {
@@ -195,7 +432,7 @@ var WorkersAIChatLanguageModel = class {
195
432
  // @ts-expect-error - this is unreachable code
196
433
  // TODO: fixme
197
434
  case "object-grammar": {
198
- throw new UnsupportedFunctionalityError2({
435
+ throw new UnsupportedFunctionalityError({
199
436
  functionality: "object-grammar mode"
200
437
  });
201
438
  }
@@ -325,10 +562,6 @@ var WorkersAIChatLanguageModel = class {
325
562
  };
326
563
  }
327
564
  };
328
- var workersAIChatResponseSchema = z.object({
329
- response: z.string()
330
- });
331
- var workersAIChatChunkSchema = z.instanceof(Uint8Array);
332
565
  function prepareToolsAndToolChoice(mode) {
333
566
  const tools = mode.tools?.length ? mode.tools : void 0;
334
567
  if (tools == null) {
@@ -360,9 +593,7 @@ function prepareToolsAndToolChoice(mode) {
360
593
  // so we filter the tools and force the tool choice through 'any'
361
594
  case "tool":
362
595
  return {
363
- tools: mappedTools.filter(
364
- (tool) => tool.function.name === toolChoice.toolName
365
- ),
596
+ tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
366
597
  tool_choice: "any"
367
598
  };
368
599
  default: {
@@ -375,33 +606,94 @@ function lastMessageWasUser(messages) {
375
606
  return messages.length > 0 && messages[messages.length - 1].role === "user";
376
607
  }
377
608
 
378
- // src/utils.ts
379
- function createRun(accountId, apiKey) {
380
- return async (model, inputs, options) => {
381
- const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
382
- const body = JSON.stringify(inputs);
383
- const headers = {
384
- "Content-Type": "application/json",
385
- Authorization: `Bearer ${apiKey}`
386
- };
387
- const response = await fetch(url, {
388
- method: "POST",
389
- headers,
390
- body
391
- });
392
- if (options?.returnRawResponse) {
393
- return response;
609
+ // src/workersai-image-model.ts
610
+ var WorkersAIImageModel = class {
611
+ constructor(modelId, settings, config) {
612
+ this.modelId = modelId;
613
+ this.settings = settings;
614
+ this.config = config;
615
+ __publicField(this, "specificationVersion", "v1");
616
+ }
617
+ get maxImagesPerCall() {
618
+ return this.settings.maxImagesPerCall ?? 1;
619
+ }
620
+ get provider() {
621
+ return this.config.provider;
622
+ }
623
+ async doGenerate({
624
+ prompt,
625
+ n,
626
+ size,
627
+ aspectRatio,
628
+ seed
629
+ // headers,
630
+ // abortSignal,
631
+ }) {
632
+ const { width, height } = getDimensionsFromSizeString(size);
633
+ const warnings = [];
634
+ if (aspectRatio != null) {
635
+ warnings.push({
636
+ type: "unsupported-setting",
637
+ setting: "aspectRatio",
638
+ details: "This model does not support aspect ratio. Use `size` instead."
639
+ });
394
640
  }
395
- if (inputs.stream === true) {
396
- if (response.body) {
397
- return response.body;
641
+ const generateImage = async () => {
642
+ const outputStream = await this.config.binding.run(
643
+ this.modelId,
644
+ {
645
+ prompt,
646
+ seed,
647
+ width,
648
+ height
649
+ }
650
+ );
651
+ return streamToUint8Array(outputStream);
652
+ };
653
+ const images = await Promise.all(
654
+ Array.from({ length: n }, () => generateImage())
655
+ );
656
+ return {
657
+ images,
658
+ warnings,
659
+ response: {
660
+ timestamp: /* @__PURE__ */ new Date(),
661
+ modelId: this.modelId,
662
+ headers: {}
398
663
  }
399
- throw new Error("No readable body available for streaming.");
400
- }
401
- const data = await response.json();
402
- return data.result;
664
+ };
665
+ }
666
+ };
667
+ function getDimensionsFromSizeString(size) {
668
+ const [width, height] = size?.split("x") ?? [void 0, void 0];
669
+ return {
670
+ width: parseInteger(width),
671
+ height: parseInteger(height)
403
672
  };
404
673
  }
674
+ function parseInteger(value) {
675
+ if (value === "" || !value) return void 0;
676
+ const number = Number(value);
677
+ return Number.isInteger(number) ? number : void 0;
678
+ }
679
+ async function streamToUint8Array(stream2) {
680
+ const reader = stream2.getReader();
681
+ const chunks = [];
682
+ let totalLength = 0;
683
+ while (true) {
684
+ const { done, value } = await reader.read();
685
+ if (done) break;
686
+ chunks.push(value);
687
+ totalLength += value.length;
688
+ }
689
+ const result = new Uint8Array(totalLength);
690
+ let offset = 0;
691
+ for (const chunk of chunks) {
692
+ result.set(chunk, offset);
693
+ offset += chunk.length;
694
+ }
695
+ return result;
696
+ }
405
697
 
406
698
  // src/index.ts
407
699
  function createWorkersAI(options) {
@@ -411,28 +703,31 @@ function createWorkersAI(options) {
411
703
  } else {
412
704
  const { accountId, apiKey } = options;
413
705
  binding = {
414
- run: createRun(accountId, apiKey)
706
+ run: createRun({ accountId, apiKey })
415
707
  };
416
708
  }
417
709
  if (!binding) {
418
- throw new Error(
419
- "Either a binding or credentials must be provided."
420
- );
710
+ throw new Error("Either a binding or credentials must be provided.");
421
711
  }
422
712
  const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
423
713
  provider: "workersai.chat",
424
714
  binding,
425
715
  gateway: options.gateway
426
716
  });
427
- const provider = function(modelId, settings) {
717
+ const createImageModel = (modelId, settings = {}) => new WorkersAIImageModel(modelId, settings, {
718
+ provider: "workersai.image",
719
+ binding,
720
+ gateway: options.gateway
721
+ });
722
+ const provider = (modelId, settings) => {
428
723
  if (new.target) {
429
- throw new Error(
430
- "The WorkersAI model function cannot be called with the new keyword."
431
- );
724
+ throw new Error("The WorkersAI model function cannot be called with the new keyword.");
432
725
  }
433
726
  return createChatModel(modelId, settings);
434
727
  };
435
728
  provider.chat = createChatModel;
729
+ provider.image = createImageModel;
730
+ provider.imageModel = createImageModel;
436
731
  return provider;
437
732
  }
438
733
  export {