workers-ai-provider 0.1.3 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +20 -2
- package/dist/index.d.ts +55 -21
- package/dist/index.js +363 -26
- package/dist/index.js.map +1 -1
- package/package.json +40 -44
- package/src/convert-to-workersai-chat-messages.ts +94 -97
- package/src/index.ts +89 -52
- package/src/map-workersai-finish-reason.ts +12 -12
- package/src/map-workersai-usage.ts +8 -4
- package/src/utils.ts +101 -0
- package/src/workersai-chat-language-model.ts +313 -325
- package/src/workersai-chat-prompt.ts +18 -18
- package/src/workersai-chat-settings.ts +11 -11
- package/src/workersai-error.ts +10 -13
- package/src/workersai-image-config.ts +5 -0
- package/src/workersai-image-model.ts +114 -0
- package/src/workersai-image-settings.ts +3 -0
- package/src/workersai-models.ts +7 -2
package/README.md
CHANGED
@@ -22,8 +22,7 @@ binding = "AI"
|
|
22
22
|
Then in your Worker, import the factory function and create a new AI provider:
|
23
23
|
|
24
24
|
```ts
|
25
|
-
|
26
|
-
import { createWorkersAI } from "workers-ai-provider";
|
25
|
+
import { createWorkersAI } from "../../../packages/workers-ai-provider/src";
|
27
26
|
import { streamText } from "ai";
|
28
27
|
|
29
28
|
type Env = {
|
@@ -58,6 +57,25 @@ export default {
|
|
58
57
|
};
|
59
58
|
```
|
60
59
|
|
60
|
+
You can also use your Cloudflare credentials to create the provider, for example if you want to use Cloudflare AI outside of the Worker environment. For example, here is how you can use Cloudflare AI in a Node script:
|
61
|
+
|
62
|
+
```js
|
63
|
+
const workersai = createWorkersAI({
|
64
|
+
accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
|
65
|
+
apiKey: process.env.CLOUDFLARE_API_KEY
|
66
|
+
});
|
67
|
+
|
68
|
+
const text = await streamText({
|
69
|
+
model: workersai("@cf/meta/llama-2-7b-chat-int8"),
|
70
|
+
messages: [
|
71
|
+
{
|
72
|
+
role: "user",
|
73
|
+
content: "Write an essay about hello world",
|
74
|
+
},
|
75
|
+
],
|
76
|
+
});
|
77
|
+
```
|
78
|
+
|
61
79
|
For more info, refer to the documentation of the [Vercel AI SDK](https://sdk.vercel.ai/).
|
62
80
|
|
63
81
|
### Credits
|
package/dist/index.d.ts
CHANGED
@@ -1,11 +1,10 @@
|
|
1
|
-
import { LanguageModelV1 } from '@ai-sdk/provider';
|
1
|
+
import { LanguageModelV1, ImageModelV1 } from '@ai-sdk/provider';
|
2
2
|
|
3
3
|
interface WorkersAIChatSettings {
|
4
4
|
/**
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
*/
|
5
|
+
* Whether to inject a safety prompt before all conversations.
|
6
|
+
* Defaults to `false`.
|
7
|
+
*/
|
9
8
|
safePrompt?: boolean;
|
10
9
|
/**
|
11
10
|
* Optionally set Cloudflare AI Gateway options.
|
@@ -18,6 +17,7 @@ interface WorkersAIChatSettings {
|
|
18
17
|
* The names of the BaseAiTextGeneration models.
|
19
18
|
*/
|
20
19
|
type TextGenerationModels = Exclude<value2key<AiModels, BaseAiTextGeneration>, value2key<AiModels, BaseAiTextToImage>>;
|
20
|
+
type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
|
21
21
|
type value2key<T, V> = {
|
22
22
|
[K in keyof T]: T[K] extends V ? K : never;
|
23
23
|
}[keyof T];
|
@@ -40,33 +40,67 @@ declare class WorkersAIChatLanguageModel implements LanguageModelV1 {
|
|
40
40
|
doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
|
41
41
|
}
|
42
42
|
|
43
|
+
type WorkersAIImageConfig = {
|
44
|
+
provider: string;
|
45
|
+
binding: Ai;
|
46
|
+
gateway?: GatewayOptions;
|
47
|
+
};
|
48
|
+
|
49
|
+
type WorkersAIImageSettings = {
|
50
|
+
maxImagesPerCall?: number;
|
51
|
+
};
|
52
|
+
|
53
|
+
declare class WorkersAIImageModel implements ImageModelV1 {
|
54
|
+
readonly modelId: ImageGenerationModels;
|
55
|
+
readonly settings: WorkersAIImageSettings;
|
56
|
+
readonly config: WorkersAIImageConfig;
|
57
|
+
readonly specificationVersion = "v1";
|
58
|
+
get maxImagesPerCall(): number;
|
59
|
+
get provider(): string;
|
60
|
+
constructor(modelId: ImageGenerationModels, settings: WorkersAIImageSettings, config: WorkersAIImageConfig);
|
61
|
+
doGenerate({ prompt, n, size, aspectRatio, seed, }: Parameters<ImageModelV1["doGenerate"]>[0]): Promise<Awaited<ReturnType<ImageModelV1["doGenerate"]>>>;
|
62
|
+
}
|
63
|
+
|
64
|
+
type WorkersAISettings = ({
|
65
|
+
/**
|
66
|
+
* Provide a Cloudflare AI binding.
|
67
|
+
*/
|
68
|
+
binding: Ai;
|
69
|
+
/**
|
70
|
+
* Credentials must be absent when a binding is given.
|
71
|
+
*/
|
72
|
+
accountId?: never;
|
73
|
+
apiKey?: never;
|
74
|
+
} | {
|
75
|
+
/**
|
76
|
+
* Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
|
77
|
+
*/
|
78
|
+
accountId: string;
|
79
|
+
apiKey: string;
|
80
|
+
/**
|
81
|
+
* Both binding must be absent if credentials are used directly.
|
82
|
+
*/
|
83
|
+
binding?: never;
|
84
|
+
}) & {
|
85
|
+
/**
|
86
|
+
* Optionally specify a gateway.
|
87
|
+
*/
|
88
|
+
gateway?: GatewayOptions;
|
89
|
+
};
|
43
90
|
interface WorkersAI {
|
44
91
|
(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
|
45
92
|
/**
|
46
93
|
* Creates a model for text generation.
|
47
94
|
**/
|
48
95
|
chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
|
49
|
-
}
|
50
|
-
interface WorkersAISettings {
|
51
96
|
/**
|
52
|
-
*
|
53
|
-
* You can set up an AI bindings in your Workers project
|
54
|
-
* by adding the following this to `wrangler.toml`:
|
55
|
-
|
56
|
-
```toml
|
57
|
-
[ai]
|
58
|
-
binding = "AI"
|
59
|
-
```
|
97
|
+
* Creates a model for image generation.
|
60
98
|
**/
|
61
|
-
|
62
|
-
/**
|
63
|
-
* Optionally set Cloudflare AI Gateway options.
|
64
|
-
*/
|
65
|
-
gateway?: GatewayOptions;
|
99
|
+
image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
|
66
100
|
}
|
67
101
|
/**
|
68
102
|
* Create a Workers AI provider instance.
|
69
|
-
|
103
|
+
*/
|
70
104
|
declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
|
71
105
|
|
72
106
|
export { type WorkersAI, type WorkersAISettings, createWorkersAI };
|
package/dist/index.js
CHANGED
@@ -1,17 +1,161 @@
|
|
1
1
|
var __defProp = Object.defineProperty;
|
2
|
+
var __typeError = (msg) => {
|
3
|
+
throw TypeError(msg);
|
4
|
+
};
|
2
5
|
var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value;
|
3
6
|
var __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== "symbol" ? key + "" : key, value);
|
7
|
+
var __accessCheck = (obj, member, msg) => member.has(obj) || __typeError("Cannot " + msg);
|
8
|
+
var __privateGet = (obj, member, getter) => (__accessCheck(obj, member, "read from private field"), getter ? getter.call(obj) : member.get(obj));
|
9
|
+
var __privateAdd = (obj, member, value) => member.has(obj) ? __typeError("Cannot add the same private member more than once") : member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
|
10
|
+
var __privateSet = (obj, member, value, setter) => (__accessCheck(obj, member, "write to private field"), setter ? setter.call(obj, value) : member.set(obj, value), value);
|
4
11
|
|
5
|
-
// src/
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
12
|
+
// src/utils.ts
|
13
|
+
function createRun(config) {
|
14
|
+
const { accountId, apiKey } = config;
|
15
|
+
return async function run(model, inputs, options) {
|
16
|
+
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
|
17
|
+
const headers = {
|
18
|
+
"Content-Type": "application/json",
|
19
|
+
Authorization: `Bearer ${apiKey}`
|
20
|
+
};
|
21
|
+
const body = JSON.stringify(inputs);
|
22
|
+
const response = await fetch(url, {
|
23
|
+
method: "POST",
|
24
|
+
headers,
|
25
|
+
body
|
26
|
+
});
|
27
|
+
if (options?.returnRawResponse) {
|
28
|
+
return response;
|
29
|
+
}
|
30
|
+
if (inputs.stream === true) {
|
31
|
+
if (response.body) {
|
32
|
+
return response.body;
|
33
|
+
}
|
34
|
+
throw new Error("No readable body available for streaming.");
|
35
|
+
}
|
36
|
+
const data = await response.json();
|
37
|
+
return data.result;
|
38
|
+
};
|
39
|
+
}
|
40
|
+
|
41
|
+
// ../../node_modules/@ai-sdk/provider/dist/index.mjs
|
42
|
+
var marker = "vercel.ai.error";
|
43
|
+
var symbol = Symbol.for(marker);
|
44
|
+
var _a;
|
45
|
+
var _AISDKError = class _AISDKError2 extends Error {
|
46
|
+
/**
|
47
|
+
* Creates an AI SDK Error.
|
48
|
+
*
|
49
|
+
* @param {Object} params - The parameters for creating the error.
|
50
|
+
* @param {string} params.name - The name of the error.
|
51
|
+
* @param {string} params.message - The error message.
|
52
|
+
* @param {unknown} [params.cause] - The underlying cause of the error.
|
53
|
+
*/
|
54
|
+
constructor({
|
55
|
+
name: name14,
|
56
|
+
message,
|
57
|
+
cause
|
58
|
+
}) {
|
59
|
+
super(message);
|
60
|
+
this[_a] = true;
|
61
|
+
this.name = name14;
|
62
|
+
this.cause = cause;
|
63
|
+
}
|
64
|
+
/**
|
65
|
+
* Checks if the given error is an AI SDK Error.
|
66
|
+
* @param {unknown} error - The error to check.
|
67
|
+
* @returns {boolean} True if the error is an AI SDK Error, false otherwise.
|
68
|
+
*/
|
69
|
+
static isInstance(error) {
|
70
|
+
return _AISDKError2.hasMarker(error, marker);
|
71
|
+
}
|
72
|
+
static hasMarker(error, marker15) {
|
73
|
+
const markerSymbol = Symbol.for(marker15);
|
74
|
+
return error != null && typeof error === "object" && markerSymbol in error && typeof error[markerSymbol] === "boolean" && error[markerSymbol] === true;
|
75
|
+
}
|
76
|
+
};
|
77
|
+
_a = symbol;
|
78
|
+
var AISDKError = _AISDKError;
|
79
|
+
var name = "AI_APICallError";
|
80
|
+
var marker2 = `vercel.ai.error.${name}`;
|
81
|
+
var symbol2 = Symbol.for(marker2);
|
82
|
+
var _a2;
|
83
|
+
_a2 = symbol2;
|
84
|
+
var name2 = "AI_EmptyResponseBodyError";
|
85
|
+
var marker3 = `vercel.ai.error.${name2}`;
|
86
|
+
var symbol3 = Symbol.for(marker3);
|
87
|
+
var _a3;
|
88
|
+
_a3 = symbol3;
|
89
|
+
var name3 = "AI_InvalidArgumentError";
|
90
|
+
var marker4 = `vercel.ai.error.${name3}`;
|
91
|
+
var symbol4 = Symbol.for(marker4);
|
92
|
+
var _a4;
|
93
|
+
_a4 = symbol4;
|
94
|
+
var name4 = "AI_InvalidPromptError";
|
95
|
+
var marker5 = `vercel.ai.error.${name4}`;
|
96
|
+
var symbol5 = Symbol.for(marker5);
|
97
|
+
var _a5;
|
98
|
+
_a5 = symbol5;
|
99
|
+
var name5 = "AI_InvalidResponseDataError";
|
100
|
+
var marker6 = `vercel.ai.error.${name5}`;
|
101
|
+
var symbol6 = Symbol.for(marker6);
|
102
|
+
var _a6;
|
103
|
+
_a6 = symbol6;
|
104
|
+
var name6 = "AI_JSONParseError";
|
105
|
+
var marker7 = `vercel.ai.error.${name6}`;
|
106
|
+
var symbol7 = Symbol.for(marker7);
|
107
|
+
var _a7;
|
108
|
+
_a7 = symbol7;
|
109
|
+
var name7 = "AI_LoadAPIKeyError";
|
110
|
+
var marker8 = `vercel.ai.error.${name7}`;
|
111
|
+
var symbol8 = Symbol.for(marker8);
|
112
|
+
var _a8;
|
113
|
+
_a8 = symbol8;
|
114
|
+
var name8 = "AI_LoadSettingError";
|
115
|
+
var marker9 = `vercel.ai.error.${name8}`;
|
116
|
+
var symbol9 = Symbol.for(marker9);
|
117
|
+
var _a9;
|
118
|
+
_a9 = symbol9;
|
119
|
+
var name9 = "AI_NoContentGeneratedError";
|
120
|
+
var marker10 = `vercel.ai.error.${name9}`;
|
121
|
+
var symbol10 = Symbol.for(marker10);
|
122
|
+
var _a10;
|
123
|
+
_a10 = symbol10;
|
124
|
+
var name10 = "AI_NoSuchModelError";
|
125
|
+
var marker11 = `vercel.ai.error.${name10}`;
|
126
|
+
var symbol11 = Symbol.for(marker11);
|
127
|
+
var _a11;
|
128
|
+
_a11 = symbol11;
|
129
|
+
var name11 = "AI_TooManyEmbeddingValuesForCallError";
|
130
|
+
var marker12 = `vercel.ai.error.${name11}`;
|
131
|
+
var symbol12 = Symbol.for(marker12);
|
132
|
+
var _a12;
|
133
|
+
_a12 = symbol12;
|
134
|
+
var name12 = "AI_TypeValidationError";
|
135
|
+
var marker13 = `vercel.ai.error.${name12}`;
|
136
|
+
var symbol13 = Symbol.for(marker13);
|
137
|
+
var _a13;
|
138
|
+
_a13 = symbol13;
|
139
|
+
var name13 = "AI_UnsupportedFunctionalityError";
|
140
|
+
var marker14 = `vercel.ai.error.${name13}`;
|
141
|
+
var symbol14 = Symbol.for(marker14);
|
142
|
+
var _a14;
|
143
|
+
var UnsupportedFunctionalityError = class extends AISDKError {
|
144
|
+
constructor({
|
145
|
+
functionality,
|
146
|
+
message = `'${functionality}' functionality not supported.`
|
147
|
+
}) {
|
148
|
+
super({ name: name13, message });
|
149
|
+
this[_a14] = true;
|
150
|
+
this.functionality = functionality;
|
151
|
+
}
|
152
|
+
static isInstance(error) {
|
153
|
+
return AISDKError.hasMarker(error, marker14);
|
154
|
+
}
|
155
|
+
};
|
156
|
+
_a14 = symbol14;
|
10
157
|
|
11
158
|
// src/convert-to-workersai-chat-messages.ts
|
12
|
-
import {
|
13
|
-
UnsupportedFunctionalityError
|
14
|
-
} from "@ai-sdk/provider";
|
15
159
|
function convertToWorkersAIChatMessages(prompt) {
|
16
160
|
const messages = [];
|
17
161
|
for (const { role, content } of prompt) {
|
@@ -48,7 +192,10 @@ function convertToWorkersAIChatMessages(prompt) {
|
|
48
192
|
break;
|
49
193
|
}
|
50
194
|
case "tool-call": {
|
51
|
-
text = JSON.stringify({
|
195
|
+
text = JSON.stringify({
|
196
|
+
name: part.toolName,
|
197
|
+
parameters: part.args
|
198
|
+
});
|
52
199
|
toolCalls.push({
|
53
200
|
id: part.toolCallId,
|
54
201
|
type: "function",
|
@@ -68,10 +215,10 @@ function convertToWorkersAIChatMessages(prompt) {
|
|
68
215
|
messages.push({
|
69
216
|
role: "assistant",
|
70
217
|
content: text,
|
71
|
-
tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name, arguments: args } }) => ({
|
218
|
+
tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name: name14, arguments: args } }) => ({
|
72
219
|
id: "null",
|
73
220
|
type: "function",
|
74
|
-
function: { name, arguments: args }
|
221
|
+
function: { name: name14, arguments: args }
|
75
222
|
})) : void 0
|
76
223
|
});
|
77
224
|
break;
|
@@ -95,8 +242,98 @@ function convertToWorkersAIChatMessages(prompt) {
|
|
95
242
|
return messages;
|
96
243
|
}
|
97
244
|
|
98
|
-
//
|
99
|
-
|
245
|
+
// ../../node_modules/fetch-event-stream/esm/deps/jsr.io/@std/streams/0.221.0/text_line_stream.js
|
246
|
+
var _currentLine;
|
247
|
+
var TextLineStream = class extends TransformStream {
|
248
|
+
/** Constructs a new instance. */
|
249
|
+
constructor(options = { allowCR: false }) {
|
250
|
+
super({
|
251
|
+
transform: (chars, controller) => {
|
252
|
+
chars = __privateGet(this, _currentLine) + chars;
|
253
|
+
while (true) {
|
254
|
+
const lfIndex = chars.indexOf("\n");
|
255
|
+
const crIndex = options.allowCR ? chars.indexOf("\r") : -1;
|
256
|
+
if (crIndex !== -1 && crIndex !== chars.length - 1 && (lfIndex === -1 || lfIndex - 1 > crIndex)) {
|
257
|
+
controller.enqueue(chars.slice(0, crIndex));
|
258
|
+
chars = chars.slice(crIndex + 1);
|
259
|
+
continue;
|
260
|
+
}
|
261
|
+
if (lfIndex === -1)
|
262
|
+
break;
|
263
|
+
const endIndex = chars[lfIndex - 1] === "\r" ? lfIndex - 1 : lfIndex;
|
264
|
+
controller.enqueue(chars.slice(0, endIndex));
|
265
|
+
chars = chars.slice(lfIndex + 1);
|
266
|
+
}
|
267
|
+
__privateSet(this, _currentLine, chars);
|
268
|
+
},
|
269
|
+
flush: (controller) => {
|
270
|
+
if (__privateGet(this, _currentLine) === "")
|
271
|
+
return;
|
272
|
+
const currentLine = options.allowCR && __privateGet(this, _currentLine).endsWith("\r") ? __privateGet(this, _currentLine).slice(0, -1) : __privateGet(this, _currentLine);
|
273
|
+
controller.enqueue(currentLine);
|
274
|
+
}
|
275
|
+
});
|
276
|
+
__privateAdd(this, _currentLine, "");
|
277
|
+
}
|
278
|
+
};
|
279
|
+
_currentLine = new WeakMap();
|
280
|
+
|
281
|
+
// ../../node_modules/fetch-event-stream/esm/utils.js
|
282
|
+
function stream(input) {
|
283
|
+
let decoder = new TextDecoderStream();
|
284
|
+
let split2 = new TextLineStream({ allowCR: true });
|
285
|
+
return input.pipeThrough(decoder).pipeThrough(split2);
|
286
|
+
}
|
287
|
+
function split(input) {
|
288
|
+
let rgx = /[:]\s*/;
|
289
|
+
let match = rgx.exec(input);
|
290
|
+
let idx = match && match.index;
|
291
|
+
if (idx) {
|
292
|
+
return [
|
293
|
+
input.substring(0, idx),
|
294
|
+
input.substring(idx + match[0].length)
|
295
|
+
];
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
// ../../node_modules/fetch-event-stream/esm/mod.js
|
300
|
+
async function* events(res, signal) {
|
301
|
+
if (!res.body)
|
302
|
+
return;
|
303
|
+
let iter = stream(res.body);
|
304
|
+
let line, reader = iter.getReader();
|
305
|
+
let event;
|
306
|
+
for (; ; ) {
|
307
|
+
if (signal && signal.aborted) {
|
308
|
+
return reader.cancel();
|
309
|
+
}
|
310
|
+
line = await reader.read();
|
311
|
+
if (line.done)
|
312
|
+
return;
|
313
|
+
if (!line.value) {
|
314
|
+
if (event)
|
315
|
+
yield event;
|
316
|
+
event = void 0;
|
317
|
+
continue;
|
318
|
+
}
|
319
|
+
let [field, value] = split(line.value) || [];
|
320
|
+
if (!field)
|
321
|
+
continue;
|
322
|
+
if (field === "data") {
|
323
|
+
event || (event = {});
|
324
|
+
event[field] = event[field] ? event[field] + "\n" + value : value;
|
325
|
+
} else if (field === "event") {
|
326
|
+
event || (event = {});
|
327
|
+
event[field] = value;
|
328
|
+
} else if (field === "id") {
|
329
|
+
event || (event = {});
|
330
|
+
event[field] = +value || value;
|
331
|
+
} else if (field === "retry") {
|
332
|
+
event || (event = {});
|
333
|
+
event[field] = +value || void 0;
|
334
|
+
}
|
335
|
+
}
|
336
|
+
}
|
100
337
|
|
101
338
|
// src/map-workersai-usage.ts
|
102
339
|
function mapWorkersAIUsage(output) {
|
@@ -195,7 +432,7 @@ var WorkersAIChatLanguageModel = class {
|
|
195
432
|
// @ts-expect-error - this is unreachable code
|
196
433
|
// TODO: fixme
|
197
434
|
case "object-grammar": {
|
198
|
-
throw new
|
435
|
+
throw new UnsupportedFunctionalityError({
|
199
436
|
functionality: "object-grammar mode"
|
200
437
|
});
|
201
438
|
}
|
@@ -325,10 +562,6 @@ var WorkersAIChatLanguageModel = class {
|
|
325
562
|
};
|
326
563
|
}
|
327
564
|
};
|
328
|
-
var workersAIChatResponseSchema = z.object({
|
329
|
-
response: z.string()
|
330
|
-
});
|
331
|
-
var workersAIChatChunkSchema = z.instanceof(Uint8Array);
|
332
565
|
function prepareToolsAndToolChoice(mode) {
|
333
566
|
const tools = mode.tools?.length ? mode.tools : void 0;
|
334
567
|
if (tools == null) {
|
@@ -360,9 +593,7 @@ function prepareToolsAndToolChoice(mode) {
|
|
360
593
|
// so we filter the tools and force the tool choice through 'any'
|
361
594
|
case "tool":
|
362
595
|
return {
|
363
|
-
tools: mappedTools.filter(
|
364
|
-
(tool) => tool.function.name === toolChoice.toolName
|
365
|
-
),
|
596
|
+
tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
|
366
597
|
tool_choice: "any"
|
367
598
|
};
|
368
599
|
default: {
|
@@ -375,22 +606,128 @@ function lastMessageWasUser(messages) {
|
|
375
606
|
return messages.length > 0 && messages[messages.length - 1].role === "user";
|
376
607
|
}
|
377
608
|
|
609
|
+
// src/workersai-image-model.ts
|
610
|
+
var WorkersAIImageModel = class {
|
611
|
+
constructor(modelId, settings, config) {
|
612
|
+
this.modelId = modelId;
|
613
|
+
this.settings = settings;
|
614
|
+
this.config = config;
|
615
|
+
__publicField(this, "specificationVersion", "v1");
|
616
|
+
}
|
617
|
+
get maxImagesPerCall() {
|
618
|
+
return this.settings.maxImagesPerCall ?? 1;
|
619
|
+
}
|
620
|
+
get provider() {
|
621
|
+
return this.config.provider;
|
622
|
+
}
|
623
|
+
async doGenerate({
|
624
|
+
prompt,
|
625
|
+
n,
|
626
|
+
size,
|
627
|
+
aspectRatio,
|
628
|
+
seed
|
629
|
+
// headers,
|
630
|
+
// abortSignal,
|
631
|
+
}) {
|
632
|
+
const { width, height } = getDimensionsFromSizeString(size);
|
633
|
+
const warnings = [];
|
634
|
+
if (aspectRatio != null) {
|
635
|
+
warnings.push({
|
636
|
+
type: "unsupported-setting",
|
637
|
+
setting: "aspectRatio",
|
638
|
+
details: "This model does not support aspect ratio. Use `size` instead."
|
639
|
+
});
|
640
|
+
}
|
641
|
+
const generateImage = async () => {
|
642
|
+
const outputStream = await this.config.binding.run(
|
643
|
+
this.modelId,
|
644
|
+
{
|
645
|
+
prompt,
|
646
|
+
seed,
|
647
|
+
width,
|
648
|
+
height
|
649
|
+
}
|
650
|
+
);
|
651
|
+
return streamToUint8Array(outputStream);
|
652
|
+
};
|
653
|
+
const images = await Promise.all(
|
654
|
+
Array.from({ length: n }, () => generateImage())
|
655
|
+
);
|
656
|
+
return {
|
657
|
+
images,
|
658
|
+
warnings,
|
659
|
+
response: {
|
660
|
+
timestamp: /* @__PURE__ */ new Date(),
|
661
|
+
modelId: this.modelId,
|
662
|
+
headers: {}
|
663
|
+
}
|
664
|
+
};
|
665
|
+
}
|
666
|
+
};
|
667
|
+
function getDimensionsFromSizeString(size) {
|
668
|
+
const [width, height] = size?.split("x") ?? [void 0, void 0];
|
669
|
+
return {
|
670
|
+
width: parseInteger(width),
|
671
|
+
height: parseInteger(height)
|
672
|
+
};
|
673
|
+
}
|
674
|
+
function parseInteger(value) {
|
675
|
+
if (value === "" || !value) return void 0;
|
676
|
+
const number = Number(value);
|
677
|
+
return Number.isInteger(number) ? number : void 0;
|
678
|
+
}
|
679
|
+
async function streamToUint8Array(stream2) {
|
680
|
+
const reader = stream2.getReader();
|
681
|
+
const chunks = [];
|
682
|
+
let totalLength = 0;
|
683
|
+
while (true) {
|
684
|
+
const { done, value } = await reader.read();
|
685
|
+
if (done) break;
|
686
|
+
chunks.push(value);
|
687
|
+
totalLength += value.length;
|
688
|
+
}
|
689
|
+
const result = new Uint8Array(totalLength);
|
690
|
+
let offset = 0;
|
691
|
+
for (const chunk of chunks) {
|
692
|
+
result.set(chunk, offset);
|
693
|
+
offset += chunk.length;
|
694
|
+
}
|
695
|
+
return result;
|
696
|
+
}
|
697
|
+
|
378
698
|
// src/index.ts
|
379
699
|
function createWorkersAI(options) {
|
700
|
+
let binding;
|
701
|
+
if (options.binding) {
|
702
|
+
binding = options.binding;
|
703
|
+
} else {
|
704
|
+
const { accountId, apiKey } = options;
|
705
|
+
binding = {
|
706
|
+
run: createRun({ accountId, apiKey })
|
707
|
+
};
|
708
|
+
}
|
709
|
+
if (!binding) {
|
710
|
+
throw new Error("Either a binding or credentials must be provided.");
|
711
|
+
}
|
380
712
|
const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
|
381
713
|
provider: "workersai.chat",
|
382
|
-
binding
|
714
|
+
binding,
|
715
|
+
gateway: options.gateway
|
716
|
+
});
|
717
|
+
const createImageModel = (modelId, settings = {}) => new WorkersAIImageModel(modelId, settings, {
|
718
|
+
provider: "workersai.image",
|
719
|
+
binding,
|
383
720
|
gateway: options.gateway
|
384
721
|
});
|
385
|
-
const provider =
|
722
|
+
const provider = (modelId, settings) => {
|
386
723
|
if (new.target) {
|
387
|
-
throw new Error(
|
388
|
-
"The WorkersAI model function cannot be called with the new keyword."
|
389
|
-
);
|
724
|
+
throw new Error("The WorkersAI model function cannot be called with the new keyword.");
|
390
725
|
}
|
391
726
|
return createChatModel(modelId, settings);
|
392
727
|
};
|
393
728
|
provider.chat = createChatModel;
|
729
|
+
provider.image = createImageModel;
|
730
|
+
provider.imageModel = createImageModel;
|
394
731
|
return provider;
|
395
732
|
}
|
396
733
|
export {
|