workers-ai-provider 0.7.5 → 2.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +30 -0
- package/dist/chunk-H3ZBSMAH.js +300 -0
- package/dist/chunk-H3ZBSMAH.js.map +1 -0
- package/dist/chunk-LOLDRYLH.js +35 -0
- package/dist/chunk-LOLDRYLH.js.map +1 -0
- package/dist/index.d.ts +17 -15
- package/dist/index.js +12666 -259
- package/dist/index.js.map +1 -1
- package/dist/token-4SRL5WJU.js +63 -0
- package/dist/token-4SRL5WJU.js.map +1 -0
- package/dist/token-util-24B4MTMT.js +6 -0
- package/dist/token-util-24B4MTMT.js.map +1 -0
- package/package.json +5 -5
- package/src/autorag-chat-language-model.ts +76 -54
- package/src/convert-to-workersai-chat-messages.ts +14 -12
- package/src/index.ts +9 -5
- package/src/map-workersai-finish-reason.ts +2 -2
- package/src/map-workersai-usage.ts +3 -2
- package/src/streaming.ts +53 -15
- package/src/utils.ts +14 -21
- package/src/workers-ai-embedding-model.ts +11 -9
- package/src/workersai-chat-language-model.ts +139 -69
- package/src/workersai-error.ts +1 -1
- package/src/workersai-image-model.ts +6 -6

package/dist/token-4SRL5WJU.js
ADDED

```diff
@@ -0,0 +1,63 @@
+import {
+  require_token_error,
+  require_token_util
+} from "./chunk-H3ZBSMAH.js";
+import {
+  __commonJS
+} from "./chunk-LOLDRYLH.js";
+
+// ../../node_modules/.pnpm/@vercel+oidc@3.0.5/node_modules/@vercel/oidc/dist/token.js
+var require_token = __commonJS({
+  "../../node_modules/.pnpm/@vercel+oidc@3.0.5/node_modules/@vercel/oidc/dist/token.js"(exports, module) {
+    var __defProp = Object.defineProperty;
+    var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
+    var __getOwnPropNames = Object.getOwnPropertyNames;
+    var __hasOwnProp = Object.prototype.hasOwnProperty;
+    var __export = (target, all) => {
+      for (var name in all)
+        __defProp(target, name, { get: all[name], enumerable: true });
+    };
+    var __copyProps = (to, from, except, desc) => {
+      if (from && typeof from === "object" || typeof from === "function") {
+        for (let key of __getOwnPropNames(from))
+          if (!__hasOwnProp.call(to, key) && key !== except)
+            __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
+      }
+      return to;
+    };
+    var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
+    var token_exports = {};
+    __export(token_exports, {
+      refreshToken: () => refreshToken
+    });
+    module.exports = __toCommonJS(token_exports);
+    var import_token_error = require_token_error();
+    var import_token_util = require_token_util();
+    async function refreshToken() {
+      const { projectId, teamId } = (0, import_token_util.findProjectInfo)();
+      let maybeToken = (0, import_token_util.loadToken)(projectId);
+      if (!maybeToken || (0, import_token_util.isExpired)((0, import_token_util.getTokenPayload)(maybeToken.token))) {
+        const authToken = (0, import_token_util.getVercelCliToken)();
+        if (!authToken) {
+          throw new import_token_error.VercelOidcTokenError(
+            "Failed to refresh OIDC token: login to vercel cli"
+          );
+        }
+        if (!projectId) {
+          throw new import_token_error.VercelOidcTokenError(
+            "Failed to refresh OIDC token: project id not found"
+          );
+        }
+        maybeToken = await (0, import_token_util.getVercelOidcToken)(authToken, projectId, teamId);
+        if (!maybeToken) {
+          throw new import_token_error.VercelOidcTokenError("Failed to refresh OIDC token");
+        }
+        (0, import_token_util.saveToken)(maybeToken, projectId);
+      }
+      process.env.VERCEL_OIDC_TOKEN = maybeToken.token;
+      return;
+    }
+  }
+});
+export default require_token();
+//# sourceMappingURL=token-4SRL5WJU.js.map
```

package/dist/token-4SRL5WJU.js.map
ADDED

```diff
@@ -0,0 +1 @@
+{"version":3,"sources":["../../../node_modules/.pnpm/@vercel+oidc@3.0.5/node_modules/@vercel/oidc/dist/token.js"],"sourcesContent":["\"use strict\";\nvar __defProp = Object.defineProperty;\nvar __getOwnPropDesc = Object.getOwnPropertyDescriptor;\nvar __getOwnPropNames = Object.getOwnPropertyNames;\nvar __hasOwnProp = Object.prototype.hasOwnProperty;\nvar __export = (target, all) => {\n  for (var name in all)\n    __defProp(target, name, { get: all[name], enumerable: true });\n};\nvar __copyProps = (to, from, except, desc) => {\n  if (from && typeof from === \"object\" || typeof from === \"function\") {\n    for (let key of __getOwnPropNames(from))\n      if (!__hasOwnProp.call(to, key) && key !== except)\n        __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });\n  }\n  return to;\n};\nvar __toCommonJS = (mod) => __copyProps(__defProp({}, \"__esModule\", { value: true }), mod);\nvar token_exports = {};\n__export(token_exports, {\n  refreshToken: () => refreshToken\n});\nmodule.exports = __toCommonJS(token_exports);\nvar import_token_error = require(\"./token-error\");\nvar import_token_util = require(\"./token-util\");\nasync function refreshToken() {\n  const { projectId, teamId } = (0, import_token_util.findProjectInfo)();\n  let maybeToken = (0, import_token_util.loadToken)(projectId);\n  if (!maybeToken || (0, import_token_util.isExpired)((0, import_token_util.getTokenPayload)(maybeToken.token))) {\n    const authToken = (0, import_token_util.getVercelCliToken)();\n    if (!authToken) {\n      throw new import_token_error.VercelOidcTokenError(\n        \"Failed to refresh OIDC token: login to vercel cli\"\n      );\n    }\n    if (!projectId) {\n      throw new import_token_error.VercelOidcTokenError(\n        \"Failed to refresh OIDC token: project id not found\"\n      );\n    }\n    maybeToken = await (0, import_token_util.getVercelOidcToken)(authToken, projectId, teamId);\n    if (!maybeToken) {\n      throw new import_token_error.VercelOidcTokenError(\"Failed to refresh OIDC token\");\n    }\n    (0, import_token_util.saveToken)(maybeToken, projectId);\n  }\n  process.env.VERCEL_OIDC_TOKEN = maybeToken.token;\n  return;\n}\n// Annotate the CommonJS export names for ESM import in node:\n0 && (module.exports = {\n  refreshToken\n});\n"],"mappings":";;;;;;;;;AAAA;AAAA;AACA,QAAI,YAAY,OAAO;AACvB,QAAI,mBAAmB,OAAO;AAC9B,QAAI,oBAAoB,OAAO;AAC/B,QAAI,eAAe,OAAO,UAAU;AACpC,QAAI,WAAW,CAAC,QAAQ,QAAQ;AAC9B,eAAS,QAAQ;AACf,kBAAU,QAAQ,MAAM,EAAE,KAAK,IAAI,IAAI,GAAG,YAAY,KAAK,CAAC;AAAA,IAChE;AACA,QAAI,cAAc,CAAC,IAAI,MAAM,QAAQ,SAAS;AAC5C,UAAI,QAAQ,OAAO,SAAS,YAAY,OAAO,SAAS,YAAY;AAClE,iBAAS,OAAO,kBAAkB,IAAI;AACpC,cAAI,CAAC,aAAa,KAAK,IAAI,GAAG,KAAK,QAAQ;AACzC,sBAAU,IAAI,KAAK,EAAE,KAAK,MAAM,KAAK,GAAG,GAAG,YAAY,EAAE,OAAO,iBAAiB,MAAM,GAAG,MAAM,KAAK,WAAW,CAAC;AAAA,MACvH;AACA,aAAO;AAAA,IACT;AACA,QAAI,eAAe,CAAC,QAAQ,YAAY,UAAU,CAAC,GAAG,cAAc,EAAE,OAAO,KAAK,CAAC,GAAG,GAAG;AACzF,QAAI,gBAAgB,CAAC;AACrB,aAAS,eAAe;AAAA,MACtB,cAAc,MAAM;AAAA,IACtB,CAAC;AACD,WAAO,UAAU,aAAa,aAAa;AAC3C,QAAI,qBAAqB;AACzB,QAAI,oBAAoB;AACxB,mBAAe,eAAe;AAC5B,YAAM,EAAE,WAAW,OAAO,KAAK,GAAG,kBAAkB,iBAAiB;AACrE,UAAI,cAAc,GAAG,kBAAkB,WAAW,SAAS;AAC3D,UAAI,CAAC,eAAe,GAAG,kBAAkB,YAAY,GAAG,kBAAkB,iBAAiB,WAAW,KAAK,CAAC,GAAG;AAC7G,cAAM,aAAa,GAAG,kBAAkB,mBAAmB;AAC3D,YAAI,CAAC,WAAW;AACd,gBAAM,IAAI,mBAAmB;AAAA,YAC3B;AAAA,UACF;AAAA,QACF;AACA,YAAI,CAAC,WAAW;AACd,gBAAM,IAAI,mBAAmB;AAAA,YAC3B;AAAA,UACF;AAAA,QACF;AACA,qBAAa,OAAO,GAAG,kBAAkB,oBAAoB,WAAW,WAAW,MAAM;AACzF,YAAI,CAAC,YAAY;AACf,gBAAM,IAAI,mBAAmB,qBAAqB,8BAA8B;AAAA,QAClF;AACA,SAAC,GAAG,kBAAkB,WAAW,YAAY,SAAS;AAAA,MACxD;AACA,cAAQ,IAAI,oBAAoB,WAAW;AAC3C;AAAA,IACF;AAAA;AAAA;","names":[]}
```

package/dist/token-util-24B4MTMT.js.map
ADDED

```diff
@@ -0,0 +1 @@
+{"version":3,"sources":[],"sourcesContent":[],"mappings":"","names":[]}
```
package/package.json
CHANGED

```diff
@@ -2,7 +2,7 @@
   "name": "workers-ai-provider",
   "description": "Workers AI Provider for the vercel AI SDK",
   "type": "module",
-  "version": "0.7.5",
+  "version": "2.0.1",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
   "repository": {
@@ -31,12 +31,12 @@
     "serverless"
   ],
   "dependencies": {
-    "@ai-sdk/provider": "^
-    "@ai-sdk/provider-utils": "^
+    "@ai-sdk/provider": "^2.0.0",
+    "@ai-sdk/provider-utils": "^3.0.19"
   },
   "devDependencies": {
-    "@cloudflare/workers-types": "^4.
-    "zod": "^3.25.
+    "@cloudflare/workers-types": "^4.20251221.0",
+    "zod": "^3.25.76"
   },
   "scripts": {
     "build": "rm -rf dist && tsup src/index.ts --dts --sourcemap --format esm --target es2020",
```
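
The dependency bumps tell the story of this major release: `@ai-sdk/provider@^2.0.0` and `@ai-sdk/provider-utils@^3.0.19` belong to the AI SDK v5 line, so 2.0.1 reimplements the provider against the `LanguageModelV2` specification. A minimal consumer sketch against the new version (the Worker shape and model id are illustrative, not taken from this diff):

```ts
import { generateText } from "ai";
import { createWorkersAI } from "workers-ai-provider";

export default {
  async fetch(_request: Request, env: { AI: Ai }): Promise<Response> {
    // Bind the provider to the Worker's AI binding.
    const workersai = createWorkersAI({ binding: env.AI });
    const { text, usage } = await generateText({
      model: workersai("@cf/meta/llama-3.1-8b-instruct"), // illustrative model id
      prompt: "Explain Workers AI in one sentence.",
    });
    // `usage` now has the v5 shape: inputTokens / outputTokens / totalTokens
    // (see the map-workersai-usage.ts change below).
    return new Response(`${text} [${usage.totalTokens} tokens]`);
  },
};
```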

package/src/autorag-chat-language-model.ts
CHANGED

```diff
@@ -1,7 +1,7 @@
-import {
-  LanguageModelV1,
-  LanguageModelV1CallWarning,
-  LanguageModelV1StreamPart,
+import type {
+  LanguageModelV2,
+  LanguageModelV2CallWarning,
+  LanguageModelV2StreamPart,
 } from "@ai-sdk/provider";
 
 import type { AutoRAGChatSettings } from "./autorag-chat-settings";
@@ -17,10 +17,14 @@ type AutoRAGChatConfig = {
   gateway?: GatewayOptions;
 };
 
-export class AutoRAGChatLanguageModel implements LanguageModelV1 {
-  readonly specificationVersion = "v1";
+export class AutoRAGChatLanguageModel implements LanguageModelV2 {
+  readonly specificationVersion = "v2";
   readonly defaultObjectGenerationMode = "json";
 
+  readonly supportedUrls: Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>> = {
+    // TODO: I think No Supported URLs?
+  };
+
   readonly modelId: TextGenerationModels;
   readonly settings: AutoRAGChatSettings;
 
@@ -41,14 +45,14 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
   }
 
   private getArgs({
-    mode,
+    responseFormat,
     prompt,
+    tools,
+    toolChoice,
     frequencyPenalty,
     presencePenalty,
-  }: Parameters<LanguageModelV1["doGenerate"]>[0]) {
-    const { type } = mode;
-
-    const warnings: LanguageModelV1CallWarning[] = [];
+  }: Parameters<LanguageModelV2["doGenerate"]>[0]) {
+    const warnings: LanguageModelV2CallWarning[] = [];
 
     if (frequencyPenalty != null) {
       warnings.push({
@@ -71,20 +75,21 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
       model: this.modelId,
     };
 
+    const type = responseFormat?.type ?? "text";
     switch (type) {
-      case "regular": {
+      case "text": {
         return {
-          args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
+          args: { ...baseArgs, ...prepareToolsAndToolChoice(tools, toolChoice) },
           warnings,
         };
       }
 
-      case "object-json": {
+      case "json": {
         return {
           args: {
             ...baseArgs,
             response_format: {
-              json_schema: mode.schema,
+              json_schema: responseFormat?.type === "json" && responseFormat.schema,
               type: "json_schema",
             },
             tools: undefined,
@@ -93,25 +98,6 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
         };
       }
 
-      case "object-tool": {
-        return {
-          args: {
-            ...baseArgs,
-            tool_choice: "any",
-            tools: [{ function: mode.tool, type: "function" }],
-          },
-          warnings,
-        };
-      }
-
-      // @ts-expect-error - this is unreachable code
-      // TODO: fixme
-      case "object-grammar": {
-        throw new UnsupportedFunctionalityError({
-          functionality: "object-grammar mode",
-        });
-      }
-
       default: {
         const exhaustiveCheck = type satisfies never;
         throw new Error(`Unsupported type: ${exhaustiveCheck}`);
@@ -120,10 +106,9 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
   }
 
   async doGenerate(
-    options: Parameters<LanguageModelV1["doGenerate"]>[0],
-  ): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
-    const {
-
+    options: Parameters<LanguageModelV2["doGenerate"]>[0],
+  ): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>> {
+    const { warnings } = this.getArgs(options);
     const { messages } = convertToWorkersAIChatMessages(options.prompt);
 
     const output = await this.config.binding.aiSearch({
@@ -132,40 +117,77 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
 
     return {
       finishReason: "stop",
-
-
-
-
-
+
+      content: [
+        ...output.data.map(({ file_id, filename, score }) => ({
+          type: "source" as const,
+          sourceType: "url" as const,
+          id: file_id,
+          url: filename,
+          providerMetadata: {
+            attributes: { score },
+          },
+        })),
+        {
+          type: "text" as const,
+          text: output.response,
         },
-
-
-      })), // TODO: mapWorkersAIFinishReason(response.finish_reason),
-      text: output.response,
-      toolCalls: processToolCalls(output),
+        ...processToolCalls(output),
+      ],
       usage: mapWorkersAIUsage(output),
       warnings,
     };
   }
 
   async doStream(
-    options: Parameters<LanguageModelV1["doStream"]>[0],
-  ): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
+    options: Parameters<LanguageModelV2["doStream"]>[0],
+  ): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>> {
     const { args, warnings } = this.getArgs(options);
-
     const { messages } = convertToWorkersAIChatMessages(options.prompt);
 
     const query = messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n");
 
+    // Get the underlying streaming response (assume this returns a ReadableStream<LanguageModelV2StreamPart>)
     const response = await this.config.binding.aiSearch({
       query,
       stream: true,
     });
 
+    // Create a new stream that first emits the stream-start part with warnings,
+    // then pipes through the rest of the response stream
+    const stream = new ReadableStream<LanguageModelV2StreamPart>({
+      start(controller) {
+        // Emit the stream-start part with warnings
+        controller.enqueue({
+          type: "stream-start",
+          warnings: warnings as LanguageModelV2CallWarning[],
+        });
+
+        // Pipe the rest of the response stream
+        const reader = getMappedStream(response).getReader();
+
+        function push() {
+          reader.read().then(({ done, value }) => {
+            if (done) {
+              controller.close();
+              return;
+            }
+            controller.enqueue(value);
+            push();
+          });
+        }
+        push();
+      },
+    });
+
     return {
-
-
-
+      stream,
+      request: {
+        body: {
+          rawPrompt: args.messages,
+          rawSettings: args,
+        },
+      },
     };
   }
 }
```
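
The `doGenerate` rewrite above is the heart of the v1 to v2 migration: instead of parallel `text`, `toolCalls`, and `sources` fields, a `LanguageModelV2` result carries a single ordered `content` array. A sketch of what this model now returns for an AutoRAG answer (all values illustrative):

```ts
import type { LanguageModelV2Content } from "@ai-sdk/provider";

// One "source" part per retrieved document (from output.data), then the
// generated text, then any tool calls appended by processToolCalls(output).
const content: LanguageModelV2Content[] = [
  {
    type: "source",
    sourceType: "url",
    id: "file-abc123",        // AutoRAG file_id
    url: "docs/pricing.md",   // AutoRAG filename
    providerMetadata: { attributes: { score: 0.87 } },
  },
  { type: "text", text: "According to the pricing docs, ..." },
];
```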

package/src/convert-to-workersai-chat-messages.ts
CHANGED

```diff
@@ -1,19 +1,19 @@
-import type { LanguageModelV1Prompt, LanguageModelV1ProviderMetadata } from "@ai-sdk/provider";
+import type { LanguageModelV2Prompt, SharedV2ProviderMetadata } from "@ai-sdk/provider";
 import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";
 
-export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
+export function convertToWorkersAIChatMessages(prompt: LanguageModelV2Prompt): {
   messages: WorkersAIChatPrompt;
   images: {
     mimeType: string | undefined;
     image: Uint8Array;
-    providerMetadata: LanguageModelV1ProviderMetadata | undefined;
+    providerOptions: SharedV2ProviderMetadata | undefined;
   }[];
 } {
   const messages: WorkersAIChatPrompt = [];
   const images: {
     mimeType: string | undefined;
     image: Uint8Array;
-    providerMetadata: LanguageModelV1ProviderMetadata | undefined;
+    providerOptions: SharedV2ProviderMetadata | undefined;
   }[] = [];
 
   for (const { role, content } of prompt) {
@@ -31,20 +31,22 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
             case "text": {
               return part.text;
             }
-            case "image": {
+            case "file": {
               // Extract image from this part
-              if (part.image instanceof Uint8Array) {
+              if (part.data instanceof Uint8Array) {
                 // Store the image data directly as Uint8Array
                 // For Llama 3.2 Vision model, which needs array of integers
                 images.push({
-                  image: part.image,
-                  mimeType: part.mimeType,
-                  providerMetadata: part.providerMetadata,
+                  image: part.data,
+                  mimeType: part.mediaType,
+                  providerOptions: part.providerOptions,
                 });
               }
               return ""; // No text for the image part
             }
           }
+
+          return undefined;
         })
         .join("\n"),
       role: "user",
@@ -75,12 +77,12 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
         case "tool-call": {
           text = JSON.stringify({
             name: part.toolName,
-            parameters: part.args,
+            parameters: part.input,
           });
 
           toolCalls.push({
             function: {
-              arguments: JSON.stringify(part.args),
+              arguments: JSON.stringify(part.input),
               name: part.toolName,
             },
             id: part.toolCallId,
@@ -114,7 +116,7 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
       case "tool": {
         for (const [index, toolResponse] of content.entries()) {
           messages.push({
-            content: JSON.stringify(toolResponse.result),
+            content: JSON.stringify(toolResponse.output),
             name: toolResponse.toolName,
             tool_call_id: `functions.${toolResponse.toolName}:${index}`,
             role: "tool",
```
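
In v5 prompts, images arrive as generic `file` parts carrying `data` and `mediaType`, rather than v1's dedicated `image` parts, which is what the `case "file"` branch above now handles. A sketch of the conversion with a hypothetical prompt:

```ts
import type { LanguageModelV2Prompt } from "@ai-sdk/provider";
import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";

// Hypothetical v5 prompt: user text plus an inline PNG as a file part.
const prompt: LanguageModelV2Prompt = [
  {
    role: "user",
    content: [
      { type: "text", text: "What is in this image?" },
      { type: "file", mediaType: "image/png", data: new Uint8Array([0x89, 0x50, 0x4e, 0x47]) },
    ],
  },
];

const { messages, images } = convertToWorkersAIChatMessages(prompt);
// messages[0] -> { role: "user", content: "What is in this image?\n" }
// images[0]   -> { image: Uint8Array(...), mimeType: "image/png", providerOptions: undefined }
```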
package/src/index.ts
CHANGED

```diff
@@ -160,11 +160,15 @@ export function createAutoRAG(options: AutoRAGSettings): AutoRAGProvider {
   const binding = options.binding;
 
   const createChatModel = (settings: AutoRAGChatSettings = {}) =>
-
-
-
-
-
+    new AutoRAGChatLanguageModel(
+      // @ts-expect-error Needs fix from @cloudflare/workers-types for custom types
+      "@cf/meta/llama-3.3-70b-instruct-fp8-fast",
+      settings,
+      {
+        binding,
+        provider: "autorag.chat",
+      },
+    );
 
   const provider = (settings?: AutoRAGChatSettings) => {
     if (new.target) {
```
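
For context, this is roughly how the rebuilt `createAutoRAG` factory is consumed from a Worker (the binding setup and the "my-rag" instance name are illustrative):

```ts
import { generateText } from "ai";
import { createAutoRAG } from "workers-ai-provider";

export default {
  async fetch(_request: Request, env: { AI: Ai }): Promise<Response> {
    // "my-rag" is a hypothetical AutoRAG instance name.
    const autorag = createAutoRAG({ binding: env.AI.autorag("my-rag") });
    const { text, sources } = await generateText({
      model: autorag(),
      prompt: "What does the pricing page say about overages?",
    });
    // With the v2 model above, retrieved documents surface as source content
    // parts, which generateText exposes via `sources`.
    return Response.json({ text, sources });
  },
};
```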

package/src/map-workersai-finish-reason.ts
CHANGED

```diff
@@ -1,6 +1,6 @@
-import type { LanguageModelV1FinishReason } from "@ai-sdk/provider";
+import type { LanguageModelV2FinishReason } from "@ai-sdk/provider";
 
-export function mapWorkersAIFinishReason(finishReasonOrResponse: any): LanguageModelV1FinishReason {
+export function mapWorkersAIFinishReason(finishReasonOrResponse: any): LanguageModelV2FinishReason {
   let finishReason: string | null | undefined;
 
   // If it's a string/null/undefined, use it directly (original behavior)
```

package/src/map-workersai-usage.ts
CHANGED

```diff
@@ -9,7 +9,8 @@ export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImage
   };
 
   return {
-    completionTokens: usage.completion_tokens,
-    promptTokens: usage.prompt_tokens,
+    outputTokens: usage.completion_tokens,
+    inputTokens: usage.prompt_tokens,
+    totalTokens: usage.prompt_tokens + usage.completion_tokens,
   };
 }
```
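
The usage object moves from v1's `{ promptTokens, completionTokens }` to the v5 three-field shape, with `totalTokens` computed in the mapper. With made-up numbers:

```ts
// Raw Workers AI usage payload (illustrative numbers):
const raw = { usage: { prompt_tokens: 12, completion_tokens: 34 } };

// mapWorkersAIUsage(raw) now produces the AI SDK v5 shape:
const mapped = {
  inputTokens: 12,  // usage.prompt_tokens
  outputTokens: 34, // usage.completion_tokens
  totalTokens: 46,  // prompt_tokens + completion_tokens
};
```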
package/src/streaming.ts
CHANGED

```diff
@@ -1,14 +1,19 @@
-import type { LanguageModelV1StreamPart } from "@ai-sdk/provider";
+import type { LanguageModelV2StreamPart } from "@ai-sdk/provider";
+import { generateId } from "ai";
 import { events } from "fetch-event-stream";
 import { mapWorkersAIUsage } from "./map-workersai-usage";
 import { processPartialToolCalls } from "./utils";
 
 export function getMappedStream(response: Response) {
   const chunkEvent = events(response);
-  let usage = { promptTokens: 0, completionTokens: 0 };
+  let usage = { outputTokens: 0, inputTokens: 0, totalTokens: 0 };
   const partialToolCalls: any[] = [];
 
-  return new ReadableStream<LanguageModelV1StreamPart>({
+  // Track start/delta/end IDs per v5 streaming protocol
+  let textId: string | null = null;
+  let reasoningId: string | null = null;
+
+  return new ReadableStream<LanguageModelV2StreamPart>({
     async start(controller) {
       for await (const event of chunkEvent) {
         if (!event.data) {
@@ -24,33 +29,66 @@ export function getMappedStream(response: Response) {
         if (chunk.tool_calls) {
           partialToolCalls.push(...chunk.tool_calls);
         }
-
+
+        // Handle top-level response text
+        if (chunk.response?.length) {
+          if (!textId) {
+            textId = generateId();
+            controller.enqueue({ type: "text-start", id: textId });
+          }
           controller.enqueue({
-            textDelta: chunk.response,
             type: "text-delta",
+            id: textId,
+            delta: chunk.response,
           });
-
+        }
+
+        // Handle reasoning content
+        const reasoningDelta = chunk?.choices?.[0]?.delta?.reasoning_content;
+        if (reasoningDelta?.length) {
+          if (!reasoningId) {
+            reasoningId = generateId();
+            controller.enqueue({ type: "reasoning-start", id: reasoningId });
+          }
           controller.enqueue({
-            type: "reasoning",
-
+            type: "reasoning-delta",
+            id: reasoningId,
+            delta: reasoningDelta,
           });
-
+        }
+
+        // Handle text content from choices
+        const textDelta = chunk?.choices?.[0]?.delta?.content;
+        if (textDelta?.length) {
+          if (!textId) {
+            textId = generateId();
+            controller.enqueue({ type: "text-start", id: textId });
+          }
           controller.enqueue({
             type: "text-delta",
-
+            id: textId,
+            delta: textDelta,
           });
+        }
       }
 
       if (partialToolCalls.length > 0) {
         const toolCalls = processPartialToolCalls(partialToolCalls);
-        toolCalls.forEach((toolCall) => {
-          controller.enqueue({
-            type: "tool-call",
-            ...toolCall,
-          });
+        toolCalls.forEach((toolCall) => {
+          controller.enqueue(toolCall);
         });
       }
 
+      // Close any open blocks
+      if (reasoningId) {
+        controller.enqueue({ type: "reasoning-end", id: reasoningId });
+        reasoningId = null;
+      }
+      if (textId) {
+        controller.enqueue({ type: "text-end", id: textId });
+        textId = null;
+      }
+
       controller.enqueue({
         finishReason: "stop",
         type: "finish",
```
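
The stream rewrite implements the v5 part protocol: each text or reasoning block is bracketed by start and end parts that share a `generateId()`-issued id, and deltas carry that id instead of v1's bare `textDelta`. The sequence a consumer of `getMappedStream` sees now looks roughly like this (ids and token counts are illustrative):

```ts
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider";

// Illustrative part sequence for a streamed text reply. Note that with the
// code above, the text-end part is emitted only once the source stream is
// exhausted, immediately before the finish part.
const parts: LanguageModelV2StreamPart[] = [
  { type: "text-start", id: "txt_1" },
  { type: "text-delta", id: "txt_1", delta: "Hello" },
  { type: "text-delta", id: "txt_1", delta: ", world!" },
  { type: "text-end", id: "txt_1" },
  {
    type: "finish",
    finishReason: "stop",
    usage: { inputTokens: 12, outputTokens: 34, totalTokens: 46 },
  },
];
```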
package/src/utils.ts
CHANGED

```diff
@@ -1,4 +1,5 @@
-import type { LanguageModelV1, LanguageModelV1FunctionToolCall } from "@ai-sdk/provider";
+import type { LanguageModelV2, LanguageModelV2ToolCall } from "@ai-sdk/provider";
+import { generateId } from "ai";
 
 /**
  * General AI run interface with overloads to handle distinct return types.
@@ -127,30 +128,22 @@ export function createRun(config: CreateRunConfig): AiRun {
 }
 
 export function prepareToolsAndToolChoice(
-  mode: Parameters<LanguageModelV1["doGenerate"]>[0]["mode"] & {
-    type: "regular";
-  },
+  tools: Parameters<LanguageModelV2["doGenerate"]>[0]["tools"],
+  toolChoice: Parameters<LanguageModelV2["doGenerate"]>[0]["toolChoice"],
 ) {
-  // when the tools array is empty, change it to undefined to prevent errors:
-  const tools = mode.tools?.length ? mode.tools : undefined;
-
   if (tools == null) {
     return { tool_choice: undefined, tools: undefined };
   }
 
   const mappedTools = tools.map((tool) => ({
     function: {
-
-      description: tool.description,
+      description: tool.type === "function" && tool.description,
       name: tool.name,
-
-      parameters: tool.parameters,
+      parameters: tool.type === "function" && tool.inputSchema,
     },
     type: "function",
   }));
 
-  const toolChoice = mode.toolChoice;
-
   if (toolChoice == null) {
     return { tool_choice: undefined, tools: mappedTools };
   }
@@ -220,31 +213,31 @@ function mergePartialToolCalls(partialCalls: any[]) {
   return Object.values(mergedCallsByIndex);
 }
 
-function processToolCall(toolCall: any): LanguageModelV1FunctionToolCall {
+function processToolCall(toolCall: any): LanguageModelV2ToolCall {
   // Check for OpenAI format tool calls first
   if (toolCall.function && toolCall.id) {
     return {
-      args:
+      input:
         typeof toolCall.function.arguments === "string"
           ? toolCall.function.arguments
           : JSON.stringify(toolCall.function.arguments || {}),
-      toolCallId: toolCall.id,
-      toolCallType: "function",
+      toolCallId: toolCall.id || generateId(),
+      type: "tool-call",
       toolName: toolCall.function.name,
     };
   }
   return {
-    args:
+    input:
       typeof toolCall.arguments === "string"
         ? toolCall.arguments
         : JSON.stringify(toolCall.arguments || {}),
-    toolCallId: toolCall.
-    toolCallType: "function",
+    toolCallId: toolCall.id || generateId(),
+    type: "tool-call",
     toolName: toolCall.name,
   };
 }
 
-export function processToolCalls(output: any): LanguageModelV1FunctionToolCall[] {
+export function processToolCalls(output: any): LanguageModelV2ToolCall[] {
   if (output.tool_calls && Array.isArray(output.tool_calls)) {
     return output.tool_calls.map((toolCall: any) => {
       const processedToolCall = processToolCall(toolCall);
```
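
`prepareToolsAndToolChoice` now receives the v2 `tools` and `toolChoice` call options directly instead of digging them out of the v1 `mode` object, and tool definitions carry `inputSchema` rather than `parameters`. A sketch of the mapping with a hypothetical tool:

```ts
import type { LanguageModelV2FunctionTool } from "@ai-sdk/provider";
import { prepareToolsAndToolChoice } from "./utils";

// Hypothetical v2 function tool, shaped as the AI SDK passes it to doGenerate:
const tools: LanguageModelV2FunctionTool[] = [
  {
    type: "function",
    name: "getWeather",
    description: "Look up the current weather for a city",
    inputSchema: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
];

const args = prepareToolsAndToolChoice(tools, undefined);
// args.tools       -> [{ type: "function", function: { name: "getWeather",
//                        description: "...", parameters: { ...inputSchema } } }]
// args.tool_choice -> undefined (no toolChoice was passed)
```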