workers-ai-provider 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/index.ts CHANGED
@@ -1,103 +1,105 @@
1
+ import { createRun } from "./utils";
1
2
  import { WorkersAIChatLanguageModel } from "./workersai-chat-language-model";
2
3
  import type { WorkersAIChatSettings } from "./workersai-chat-settings";
3
- import type { TextGenerationModels } from "./workersai-models";
4
- import { createRun } from "./utils";
4
+ import { WorkersAIImageModel } from "./workersai-image-model";
5
+ import type { WorkersAIImageSettings } from "./workersai-image-settings";
6
+ import type { ImageGenerationModels, TextGenerationModels } from "./workersai-models";
5
7
 
6
- export type WorkersAISettings =
7
- ({
8
- /**
9
- * Provide a Cloudflare AI binding.
10
- */
11
- binding: Ai;
8
+ export type WorkersAISettings = (
9
+ | {
10
+ /**
11
+ * Provide a Cloudflare AI binding.
12
+ */
13
+ binding: Ai;
12
14
 
13
- /**
14
- * Credentials must be absent when a binding is given.
15
- */
16
- accountId?: never;
17
- apiKey?: never;
18
- }
19
- | {
20
- /**
21
- * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
22
- */
23
- accountId: string;
24
- apiKey: string;
25
- /**
26
- * Both binding must be absent if credentials are used directly.
27
- */
28
- binding?: never;
29
- }) & {
30
- /**
31
- * Optionally specify a gateway.
32
- */
33
- gateway?: GatewayOptions;
15
+ /**
16
+ * Credentials must be absent when a binding is given.
17
+ */
18
+ accountId?: never;
19
+ apiKey?: never;
20
+ }
21
+ | {
22
+ /**
23
+ * Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
24
+ */
25
+ accountId: string;
26
+ apiKey: string;
27
+ /**
28
+ * The binding must be absent if credentials are used directly.
29
+ */
30
+ binding?: never;
31
+ }
32
+ ) & {
33
+ /**
34
+ * Optionally specify a gateway.
35
+ */
36
+ gateway?: GatewayOptions;
34
37
  };
35
38
 
36
39
  export interface WorkersAI {
37
- (
38
- modelId: TextGenerationModels,
39
- settings?: WorkersAIChatSettings
40
- ): WorkersAIChatLanguageModel;
40
+ (modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
41
+ /**
42
+ * Creates a model for text generation.
43
+ **/
44
+ chat(
45
+ modelId: TextGenerationModels,
46
+ settings?: WorkersAIChatSettings,
47
+ ): WorkersAIChatLanguageModel;
41
48
 
42
- /**
43
- * Creates a model for text generation.
44
- **/
45
- chat(
46
- modelId: TextGenerationModels,
47
- settings?: WorkersAIChatSettings
48
- ): WorkersAIChatLanguageModel;
49
+ /**
50
+ * Creates a model for image generation.
51
+ **/
52
+ image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
49
53
  }
50
54
 
51
55
  /**
52
56
  * Create a Workers AI provider instance.
53
57
  */
54
58
  export function createWorkersAI(options: WorkersAISettings): WorkersAI {
55
- // Use a binding if one is directly provided. Otherwise use credentials to create
56
- // a `run` method that calls the Cloudflare REST API.
57
- let binding: Ai | undefined;
59
+ // Use a binding if one is directly provided. Otherwise use credentials to create
60
+ // a `run` method that calls the Cloudflare REST API.
61
+ let binding: Ai | undefined;
58
62
 
59
- if (options.binding) {
60
- binding = options.binding;
61
- } else {
62
- const { accountId, apiKey } = options;
63
- binding = {
64
- run: createRun(accountId, apiKey),
65
- } as Ai;
66
- }
63
+ if (options.binding) {
64
+ binding = options.binding;
65
+ } else {
66
+ const { accountId, apiKey } = options;
67
+ binding = {
68
+ run: createRun({ accountId, apiKey }),
69
+ } as Ai;
70
+ }
67
71
 
68
- if (!binding) {
69
- throw new Error(
70
- "Either a binding or credentials must be provided."
71
- );
72
- }
72
+ if (!binding) {
73
+ throw new Error("Either a binding or credentials must be provided.");
74
+ }
73
75
 
74
- /**
75
- * Helper function to create a chat model instance.
76
- */
77
- const createChatModel = (
78
- modelId: TextGenerationModels,
79
- settings: WorkersAIChatSettings = {}
80
- ) =>
81
- new WorkersAIChatLanguageModel(modelId, settings, {
82
- provider: "workersai.chat",
83
- binding,
84
- gateway: options.gateway
85
- });
76
+ const createChatModel = (modelId: TextGenerationModels, settings: WorkersAIChatSettings = {}) =>
77
+ new WorkersAIChatLanguageModel(modelId, settings, {
78
+ provider: "workersai.chat",
79
+ binding,
80
+ gateway: options.gateway,
81
+ });
86
82
 
87
- const provider = function (
88
- modelId: TextGenerationModels,
89
- settings?: WorkersAIChatSettings
90
- ) {
91
- if (new.target) {
92
- throw new Error(
93
- "The WorkersAI model function cannot be called with the new keyword."
94
- );
95
- }
96
- return createChatModel(modelId, settings);
97
- };
83
+ const createImageModel = (
84
+ modelId: ImageGenerationModels,
85
+ settings: WorkersAIImageSettings = {},
86
+ ) =>
87
+ new WorkersAIImageModel(modelId, settings, {
88
+ provider: "workersai.image",
89
+ binding,
90
+ gateway: options.gateway,
91
+ });
98
92
 
99
- provider.chat = createChatModel;
93
+ const provider = (modelId: TextGenerationModels, settings?: WorkersAIChatSettings) => {
94
+ if (new.target) {
95
+ throw new Error("The WorkersAI model function cannot be called with the new keyword.");
96
+ }
97
+ return createChatModel(modelId, settings);
98
+ };
100
99
 
101
- return provider;
102
- }
100
+ provider.chat = createChatModel;
101
+ provider.image = createImageModel;
102
+ provider.imageModel = createImageModel;
103
103
 
104
+ return provider;
105
+ }
@@ -1,17 +1,17 @@
1
1
  import type { LanguageModelV1FinishReason } from "@ai-sdk/provider";
2
2
 
3
3
  export function mapWorkersAIFinishReason(
4
- finishReason: string | null | undefined
4
+ finishReason: string | null | undefined,
5
5
  ): LanguageModelV1FinishReason {
6
- switch (finishReason) {
7
- case "stop":
8
- return "stop";
9
- case "length":
10
- case "model_length":
11
- return "length";
12
- case "tool_calls":
13
- return "tool-calls";
14
- default:
15
- return "other";
16
- }
6
+ switch (finishReason) {
7
+ case "stop":
8
+ return "stop";
9
+ case "length":
10
+ case "model_length":
11
+ return "length";
12
+ case "tool_calls":
13
+ return "tool-calls";
14
+ default:
15
+ return "other";
16
+ }
17
17
  }
@@ -1,11 +1,15 @@
1
- export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {
2
- const usage = (output as { usage: { prompt_tokens: number, completion_tokens: number} }).usage ?? {
1
+ export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {
2
+ const usage = (
3
+ output as {
4
+ usage: { prompt_tokens: number; completion_tokens: number };
5
+ }
6
+ ).usage ?? {
3
7
  prompt_tokens: 0,
4
8
  completion_tokens: 0,
5
9
  };
6
10
 
7
11
  return {
8
12
  promptTokens: usage.prompt_tokens,
9
- completionTokens: usage.completion_tokens
10
- }
13
+ completionTokens: usage.completion_tokens,
14
+ };
11
15
  }
package/src/utils.ts CHANGED
@@ -1,64 +1,101 @@
1
1
  /**
2
- * Creates a run method that mimics the Cloudflare Workers AI binding,
3
- * but uses the Cloudflare REST API under the hood.
2
+ * General AI run interface with overloads to handle distinct return types.
4
3
  *
5
- * @param accountId - Your Cloudflare account identifier.
6
- * @param apiKey - Your Cloudflare API token/key with appropriate permissions.
7
- * @returns An function matching `Ai['run']`.
4
+ * The behaviour depends on the combination of parameters:
5
+ * 1. `returnRawResponse: true` => returns the raw Response object.
6
+ * 2. `stream: true` => returns a ReadableStream (if available).
7
+ * 3. Otherwise => returns post-processed AI results.
8
8
  */
9
- export function createRun(accountId: string, apiKey: string): AiRun {
10
- return async <Name extends keyof AiModels>(model: Name, inputs: AiModels[Name]["inputs"], options?: AiOptions | undefined) => {
11
- const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
12
- const body = JSON.stringify(inputs);
9
+ export interface AiRun {
10
+ // (1) Return raw Response if `options.returnRawResponse` is `true`.
11
+ <Name extends keyof AiModels>(
12
+ model: Name,
13
+ inputs: AiModels[Name]["inputs"],
14
+ options: AiOptions & { returnRawResponse: true },
15
+ ): Promise<Response>;
13
16
 
14
- const headers = {
15
- "Content-Type": "application/json",
16
- Authorization: `Bearer ${apiKey}`,
17
- };
17
+ // (2) Return a stream if the input has `stream: true`.
18
+ <Name extends keyof AiModels>(
19
+ model: Name,
20
+ inputs: AiModels[Name]["inputs"] & { stream: true },
21
+ options?: AiOptions,
22
+ ): Promise<ReadableStream<Uint8Array>>;
18
23
 
19
- const response = await fetch(url, {
20
- method: "POST",
21
- headers,
22
- body,
23
- }) as Response;
24
-
25
- if (options?.returnRawResponse) {
26
- return response;
27
- }
24
+ // (3) Return post-processed outputs by default.
25
+ <Name extends keyof AiModels>(
26
+ model: Name,
27
+ inputs: AiModels[Name]["inputs"],
28
+ options?: AiOptions,
29
+ ): Promise<AiModels[Name]["postProcessedOutputs"]>;
30
+ }
28
31
 
29
- if ((inputs as AiTextGenerationInput).stream === true) {
30
- // If there's a stream, return the raw body so the caller can process it
31
- if (response.body) {
32
- return response.body;
33
- }
34
- throw new Error("No readable body available for streaming.");
35
- }
32
+ /**
33
+ * Parameters for configuring the Cloudflare-based AI runner.
34
+ */
35
+ export interface CreateRunConfig {
36
+ /** Your Cloudflare account identifier. */
37
+ accountId: string;
36
38
 
37
- // Otherwise, parse JSON and return the data.result
38
- const data = await response.json<{ result: AiModels[Name]["postProcessedOutputs"] }>();
39
- return data.result;
40
- };
39
+ /** Cloudflare API token/key with appropriate permissions. */
40
+ apiKey: string;
41
41
  }
42
42
 
43
- interface AiRun {
44
- // (1) Return raw response if `options.returnRawResponse` is `true`.
45
- <Name extends keyof AiModels>(
46
- model: Name,
47
- inputs: AiModels[Name]["inputs"],
48
- options: AiOptions & { returnRawResponse: true }
49
- ): Promise<Response>;
43
+ /**
44
+ * Creates a run method that emulates the Cloudflare Workers AI binding,
45
+ * but uses the Cloudflare REST API under the hood. The account ID and
46
+ * API key are captured at creation time and reused for every request.
47
+ *
48
+ * @param config An object containing:
49
+ * - `accountId`: Cloudflare account identifier.
50
+ * - `apiKey`: Cloudflare API token/key with suitable permissions.
51
+ * No other options are read: custom headers and abort signals are
52
+ * not part of this configuration.
53
+ *
54
+ * @returns A function matching the AiRun interface.
55
+ */
56
+ export function createRun(config: CreateRunConfig): AiRun {
57
+ const { accountId, apiKey } = config;
58
+
59
+ // Return the AiRun-compatible function.
60
+ return async function run<Name extends keyof AiModels>(
61
+ model: Name,
62
+ inputs: AiModels[Name]["inputs"],
63
+ options?: AiOptions,
64
+ ): Promise<Response | ReadableStream<Uint8Array> | AiModels[Name]["postProcessedOutputs"]> {
65
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
66
+
67
+ // Fixed request headers: JSON content type and bearer authorization.
68
+ const headers = {
69
+ "Content-Type": "application/json",
70
+ Authorization: `Bearer ${apiKey}`,
71
+ };
72
+
73
+ const body = JSON.stringify(inputs);
74
+
75
+ // Execute the POST request.
76
+ const response = await fetch(url, {
77
+ method: "POST",
78
+ headers,
79
+ body,
80
+ });
81
+
82
+ // (1) If the user explicitly requests the raw Response, return it as-is.
83
+ if (options?.returnRawResponse) {
84
+ return response;
85
+ }
50
86
 
51
- // (2) Return a stream if the input has `stream: true`.
52
- <Name extends keyof AiModels>(
53
- model: Name,
54
- inputs: AiModels[Name]["inputs"] & { stream: true },
55
- options?: AiOptions
56
- ): Promise<ReadableStream<Uint8Array>>;
87
+ // (2) If the AI input requests streaming, return the ReadableStream if available.
88
+ if ((inputs as AiTextGenerationInput).stream === true) {
89
+ if (response.body) {
90
+ return response.body;
91
+ }
92
+ throw new Error("No readable body available for streaming.");
93
+ }
57
94
 
58
- // (3) Return the post-processed outputs by default.
59
- <Name extends keyof AiModels>(
60
- model: Name,
61
- inputs: AiModels[Name]["inputs"],
62
- options?: AiOptions
63
- ): Promise<AiModels[Name]["postProcessedOutputs"]>;
95
+ // (3) In all other cases, parse JSON and return the result field.
96
+ const data = await response.json<{
97
+ result: AiModels[Name]["postProcessedOutputs"];
98
+ }>();
99
+ return data.result;
100
+ };
64
101
  }