workers-ai-provider 0.4.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -4,17 +4,12 @@ type StringLike = string | {
4
4
  toString(): string;
5
5
  };
6
6
 
7
- type WorkersAIChatSettings = {
7
+ type AutoRAGChatSettings = {
8
8
  /**
9
9
  * Whether to inject a safety prompt before all conversations.
10
10
  * Defaults to `false`.
11
11
  */
12
12
  safePrompt?: boolean;
13
- /**
14
- * Optionally set Cloudflare AI Gateway options.
15
- * @deprecated
16
- */
17
- gateway?: GatewayOptions;
18
13
  } & {
19
14
  /**
20
15
  * Passthrough settings that are provided directly to the run function.
@@ -31,6 +26,42 @@ type value2key<T, V> = {
31
26
  [K in keyof T]: T[K] extends V ? K : never;
32
27
  }[keyof T];
33
28
 
29
+ type AutoRAGChatConfig = {
30
+ provider: string;
31
+ binding: AutoRAG;
32
+ gateway?: GatewayOptions;
33
+ };
34
+ declare class AutoRAGChatLanguageModel implements LanguageModelV1 {
35
+ readonly specificationVersion = "v1";
36
+ readonly defaultObjectGenerationMode = "json";
37
+ readonly modelId: TextGenerationModels;
38
+ readonly settings: AutoRAGChatSettings;
39
+ private readonly config;
40
+ constructor(modelId: TextGenerationModels, settings: AutoRAGChatSettings, config: AutoRAGChatConfig);
41
+ get provider(): string;
42
+ private getArgs;
43
+ doGenerate(options: Parameters<LanguageModelV1["doGenerate"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>>;
44
+ doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
45
+ }
46
+
47
+ type WorkersAIChatSettings = {
48
+ /**
49
+ * Whether to inject a safety prompt before all conversations.
50
+ * Defaults to `false`.
51
+ */
52
+ safePrompt?: boolean;
53
+ /**
54
+ * Optionally set Cloudflare AI Gateway options.
55
+ * @deprecated
56
+ */
57
+ gateway?: GatewayOptions;
58
+ } & {
59
+ /**
60
+ * Passthrough settings that are provided directly to the run function.
61
+ */
62
+ [key: string]: StringLike;
63
+ };
64
+
34
65
  type WorkersAIChatConfig = {
35
66
  provider: string;
36
67
  binding: Ai;
@@ -111,5 +142,19 @@ interface WorkersAI {
111
142
  * Create a Workers AI provider instance.
112
143
  */
113
144
  declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
145
+ type AutoRAGSettings = {
146
+ binding: AutoRAG;
147
+ };
148
+ interface AutoRAGProvider {
149
+ (options?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
150
+ /**
151
+ * Creates a model for text generation.
152
+ **/
153
+ chat(settings?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
154
+ }
155
+ /**
156
+ * Create a Workers AI provider instance.
157
+ */
158
+ declare function createAutoRAG(options: AutoRAGSettings): AutoRAGProvider;
114
159
 
115
- export { type WorkersAI, type WorkersAISettings, createWorkersAI };
160
+ export { type AutoRAGProvider, type AutoRAGSettings, type WorkersAI, type WorkersAISettings, createAutoRAG, createWorkersAI };
package/dist/index.js CHANGED
@@ -9,51 +9,7 @@ var __privateGet = (obj, member, getter) => (__accessCheck(obj, member, "read fr
9
9
  var __privateAdd = (obj, member, value) => member.has(obj) ? __typeError("Cannot add the same private member more than once") : member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
10
10
  var __privateSet = (obj, member, value, setter) => (__accessCheck(obj, member, "write to private field"), setter ? setter.call(obj, value) : member.set(obj, value), value);
11
11
 
12
- // src/utils.ts
13
- function createRun(config) {
14
- const { accountId, apiKey } = config;
15
- return async function run(model, inputs, options) {
16
- const { gateway, prefix, extraHeaders, returnRawResponse, ...passthroughOptions } = options || {};
17
- const urlParams = new URLSearchParams();
18
- for (const [key, value] of Object.entries(passthroughOptions)) {
19
- try {
20
- const valueStr = value.toString();
21
- if (!valueStr) {
22
- continue;
23
- }
24
- urlParams.append(key, valueStr);
25
- } catch (error) {
26
- throw new Error(
27
- `Value for option '${key}' is not able to be coerced into a string.`
28
- );
29
- }
30
- }
31
- const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
32
- const headers = {
33
- "Content-Type": "application/json",
34
- Authorization: `Bearer ${apiKey}`
35
- };
36
- const body = JSON.stringify(inputs);
37
- const response = await fetch(url, {
38
- method: "POST",
39
- headers,
40
- body
41
- });
42
- if (returnRawResponse) {
43
- return response;
44
- }
45
- if (inputs.stream === true) {
46
- if (response.body) {
47
- return response.body;
48
- }
49
- throw new Error("No readable body available for streaming.");
50
- }
51
- const data = await response.json();
52
- return data.result;
53
- };
54
- }
55
-
56
- // src/workersai-chat-language-model.ts
12
+ // src/autorag-chat-language-model.ts
57
13
  import {
58
14
  UnsupportedFunctionalityError
59
15
  } from "@ai-sdk/provider";
@@ -157,6 +113,18 @@ function convertToWorkersAIChatMessages(prompt) {
157
113
  return { messages, images };
158
114
  }
159
115
 
116
+ // src/map-workersai-usage.ts
117
+ function mapWorkersAIUsage(output) {
118
+ const usage = output.usage ?? {
119
+ prompt_tokens: 0,
120
+ completion_tokens: 0
121
+ };
122
+ return {
123
+ promptTokens: usage.prompt_tokens,
124
+ completionTokens: usage.completion_tokens
125
+ };
126
+ }
127
+
160
128
  // ../../node_modules/.pnpm/fetch-event-stream@0.1.5/node_modules/fetch-event-stream/esm/deps/jsr.io/@std/streams/0.221.0/text_line_stream.js
161
129
  var _currentLine;
162
130
  var TextLineStream = class extends TransformStream {
@@ -250,19 +218,274 @@ async function* events(res, signal) {
250
218
  }
251
219
  }
252
220
 
253
- // src/map-workersai-usage.ts
254
- function mapWorkersAIUsage(output) {
255
- const usage = output.usage ?? {
256
- prompt_tokens: 0,
257
- completion_tokens: 0
258
- };
259
- return {
260
- promptTokens: usage.prompt_tokens,
261
- completionTokens: usage.completion_tokens
221
+ // src/streaming.ts
222
+ function getMappedStream(response) {
223
+ const chunkEvent = events(response);
224
+ let usage = { promptTokens: 0, completionTokens: 0 };
225
+ return new ReadableStream({
226
+ async start(controller) {
227
+ for await (const event of chunkEvent) {
228
+ if (!event.data) {
229
+ continue;
230
+ }
231
+ if (event.data === "[DONE]") {
232
+ break;
233
+ }
234
+ const chunk = JSON.parse(event.data);
235
+ if (chunk.usage) {
236
+ usage = mapWorkersAIUsage(chunk);
237
+ }
238
+ chunk.response?.length && controller.enqueue({
239
+ type: "text-delta",
240
+ textDelta: chunk.response
241
+ });
242
+ }
243
+ controller.enqueue({
244
+ type: "finish",
245
+ finishReason: "stop",
246
+ usage
247
+ });
248
+ controller.close();
249
+ }
250
+ });
251
+ }
252
+
253
+ // src/utils.ts
254
+ function createRun(config) {
255
+ const { accountId, apiKey } = config;
256
+ return async function run(model, inputs, options) {
257
+ const { gateway, prefix, extraHeaders, returnRawResponse, ...passthroughOptions } = options || {};
258
+ const urlParams = new URLSearchParams();
259
+ for (const [key, value] of Object.entries(passthroughOptions)) {
260
+ try {
261
+ const valueStr = value.toString();
262
+ if (!valueStr) {
263
+ continue;
264
+ }
265
+ urlParams.append(key, valueStr);
266
+ } catch (error) {
267
+ throw new Error(
268
+ `Value for option '${key}' is not able to be coerced into a string.`
269
+ );
270
+ }
271
+ }
272
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
273
+ const headers = {
274
+ "Content-Type": "application/json",
275
+ Authorization: `Bearer ${apiKey}`
276
+ };
277
+ const body = JSON.stringify(inputs);
278
+ const response = await fetch(url, {
279
+ method: "POST",
280
+ headers,
281
+ body
282
+ });
283
+ if (returnRawResponse) {
284
+ return response;
285
+ }
286
+ if (inputs.stream === true) {
287
+ if (response.body) {
288
+ return response.body;
289
+ }
290
+ throw new Error("No readable body available for streaming.");
291
+ }
292
+ const data = await response.json();
293
+ return data.result;
262
294
  };
263
295
  }
296
+ function prepareToolsAndToolChoice(mode) {
297
+ const tools = mode.tools?.length ? mode.tools : void 0;
298
+ if (tools == null) {
299
+ return { tools: void 0, tool_choice: void 0 };
300
+ }
301
+ const mappedTools = tools.map((tool) => ({
302
+ type: "function",
303
+ function: {
304
+ name: tool.name,
305
+ // @ts-expect-error - description is not a property of tool
306
+ description: tool.description,
307
+ // @ts-expect-error - parameters is not a property of tool
308
+ parameters: tool.parameters
309
+ }
310
+ }));
311
+ const toolChoice = mode.toolChoice;
312
+ if (toolChoice == null) {
313
+ return { tools: mappedTools, tool_choice: void 0 };
314
+ }
315
+ const type = toolChoice.type;
316
+ switch (type) {
317
+ case "auto":
318
+ return { tools: mappedTools, tool_choice: type };
319
+ case "none":
320
+ return { tools: mappedTools, tool_choice: type };
321
+ case "required":
322
+ return { tools: mappedTools, tool_choice: "any" };
323
+ // workersAI does not support tool mode directly,
324
+ // so we filter the tools and force the tool choice through 'any'
325
+ case "tool":
326
+ return {
327
+ tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
328
+ tool_choice: "any"
329
+ };
330
+ default: {
331
+ const exhaustiveCheck = type;
332
+ throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
333
+ }
334
+ }
335
+ }
336
+ function lastMessageWasUser(messages) {
337
+ return messages.length > 0 && messages[messages.length - 1].role === "user";
338
+ }
339
+ function processToolCalls(output) {
340
+ if (output.tool_calls && Array.isArray(output.tool_calls)) {
341
+ return output.tool_calls.map((toolCall) => {
342
+ if (toolCall.function && toolCall.id) {
343
+ return {
344
+ toolCallType: "function",
345
+ toolCallId: toolCall.id,
346
+ toolName: toolCall.function.name,
347
+ args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments || {})
348
+ };
349
+ }
350
+ return {
351
+ toolCallType: "function",
352
+ toolCallId: toolCall.name,
353
+ toolName: toolCall.name,
354
+ args: typeof toolCall.arguments === "string" ? toolCall.arguments : JSON.stringify(toolCall.arguments || {})
355
+ };
356
+ });
357
+ }
358
+ return [];
359
+ }
360
+
361
+ // src/autorag-chat-language-model.ts
362
+ var AutoRAGChatLanguageModel = class {
363
+ constructor(modelId, settings, config) {
364
+ __publicField(this, "specificationVersion", "v1");
365
+ __publicField(this, "defaultObjectGenerationMode", "json");
366
+ __publicField(this, "modelId");
367
+ __publicField(this, "settings");
368
+ __publicField(this, "config");
369
+ this.modelId = modelId;
370
+ this.settings = settings;
371
+ this.config = config;
372
+ }
373
+ get provider() {
374
+ return this.config.provider;
375
+ }
376
+ getArgs({
377
+ mode,
378
+ prompt,
379
+ frequencyPenalty,
380
+ presencePenalty
381
+ }) {
382
+ const type = mode.type;
383
+ const warnings = [];
384
+ if (frequencyPenalty != null) {
385
+ warnings.push({
386
+ type: "unsupported-setting",
387
+ setting: "frequencyPenalty"
388
+ });
389
+ }
390
+ if (presencePenalty != null) {
391
+ warnings.push({
392
+ type: "unsupported-setting",
393
+ setting: "presencePenalty"
394
+ });
395
+ }
396
+ const baseArgs = {
397
+ // model id:
398
+ model: this.modelId,
399
+ // messages:
400
+ messages: convertToWorkersAIChatMessages(prompt)
401
+ };
402
+ switch (type) {
403
+ case "regular": {
404
+ return {
405
+ args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
406
+ warnings
407
+ };
408
+ }
409
+ case "object-json": {
410
+ return {
411
+ args: {
412
+ ...baseArgs,
413
+ response_format: {
414
+ type: "json_schema",
415
+ json_schema: mode.schema
416
+ },
417
+ tools: void 0
418
+ },
419
+ warnings
420
+ };
421
+ }
422
+ case "object-tool": {
423
+ return {
424
+ args: {
425
+ ...baseArgs,
426
+ tool_choice: "any",
427
+ tools: [{ type: "function", function: mode.tool }]
428
+ },
429
+ warnings
430
+ };
431
+ }
432
+ // @ts-expect-error - this is unreachable code
433
+ // TODO: fixme
434
+ case "object-grammar": {
435
+ throw new UnsupportedFunctionalityError({
436
+ functionality: "object-grammar mode"
437
+ });
438
+ }
439
+ default: {
440
+ const exhaustiveCheck = type;
441
+ throw new Error(`Unsupported type: ${exhaustiveCheck}`);
442
+ }
443
+ }
444
+ }
445
+ async doGenerate(options) {
446
+ const { args, warnings } = this.getArgs(options);
447
+ const { messages } = convertToWorkersAIChatMessages(options.prompt);
448
+ const output = await this.config.binding.aiSearch({
449
+ query: messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n")
450
+ });
451
+ return {
452
+ text: output.response,
453
+ toolCalls: processToolCalls(output),
454
+ finishReason: "stop",
455
+ // TODO: mapWorkersAIFinishReason(response.finish_reason),
456
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
457
+ usage: mapWorkersAIUsage(output),
458
+ warnings,
459
+ sources: output.data.map(({ file_id, filename, score }) => ({
460
+ id: file_id,
461
+ sourceType: "url",
462
+ url: filename,
463
+ providerMetadata: {
464
+ attributes: { score }
465
+ }
466
+ }))
467
+ };
468
+ }
469
+ async doStream(options) {
470
+ const { args, warnings } = this.getArgs(options);
471
+ const { messages } = convertToWorkersAIChatMessages(options.prompt);
472
+ const query = messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n");
473
+ const response = await this.config.binding.aiSearch({
474
+ query,
475
+ stream: true
476
+ });
477
+ return {
478
+ stream: getMappedStream(response),
479
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
480
+ warnings
481
+ };
482
+ }
483
+ };
264
484
 
265
485
  // src/workersai-chat-language-model.ts
486
+ import {
487
+ UnsupportedFunctionalityError as UnsupportedFunctionalityError2
488
+ } from "@ai-sdk/provider";
266
489
  var WorkersAIChatLanguageModel = class {
267
490
  constructor(modelId, settings, config) {
268
491
  __publicField(this, "specificationVersion", "v1");
@@ -344,7 +567,7 @@ var WorkersAIChatLanguageModel = class {
344
567
  // @ts-expect-error - this is unreachable code
345
568
  // TODO: fixme
346
569
  case "object-grammar": {
347
- throw new UnsupportedFunctionalityError({
570
+ throw new UnsupportedFunctionalityError2({
348
571
  functionality: "object-grammar mode"
349
572
  });
350
573
  }
@@ -357,9 +580,7 @@ var WorkersAIChatLanguageModel = class {
357
580
  async doGenerate(options) {
358
581
  const { args, warnings } = this.getArgs(options);
359
582
  const { gateway, safePrompt, ...passthroughOptions } = this.settings;
360
- const { messages, images } = convertToWorkersAIChatMessages(
361
- options.prompt
362
- );
583
+ const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
363
584
  if (images.length !== 0 && images.length !== 1) {
364
585
  throw new Error("Multiple images are not yet supported as input");
365
586
  }
@@ -395,9 +616,7 @@ var WorkersAIChatLanguageModel = class {
395
616
  }
396
617
  async doStream(options) {
397
618
  const { args, warnings } = this.getArgs(options);
398
- const { messages, images } = convertToWorkersAIChatMessages(
399
- options.prompt
400
- );
619
+ const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
401
620
  if (args.tools?.length && lastMessageWasUser(messages)) {
402
621
  const response2 = await this.doGenerate(options);
403
622
  if (response2 instanceof ReadableStream) {
@@ -457,106 +676,13 @@ var WorkersAIChatLanguageModel = class {
457
676
  if (!(response instanceof ReadableStream)) {
458
677
  throw new Error("This shouldn't happen");
459
678
  }
460
- const chunkEvent = events(new Response(response));
461
- let usage = { promptTokens: 0, completionTokens: 0 };
462
679
  return {
463
- stream: new ReadableStream({
464
- async start(controller) {
465
- for await (const event of chunkEvent) {
466
- if (!event.data) {
467
- continue;
468
- }
469
- if (event.data === "[DONE]") {
470
- break;
471
- }
472
- const chunk = JSON.parse(event.data);
473
- if (chunk.usage) {
474
- usage = mapWorkersAIUsage(chunk);
475
- }
476
- chunk.response?.length && controller.enqueue({
477
- type: "text-delta",
478
- textDelta: chunk.response
479
- });
480
- }
481
- controller.enqueue({
482
- type: "finish",
483
- finishReason: "stop",
484
- usage
485
- });
486
- controller.close();
487
- }
488
- }),
680
+ stream: getMappedStream(new Response(response)),
489
681
  rawCall: { rawPrompt: messages, rawSettings: args },
490
682
  warnings
491
683
  };
492
684
  }
493
685
  };
494
- function processToolCalls(output) {
495
- if (output.tool_calls && Array.isArray(output.tool_calls)) {
496
- return output.tool_calls.map((toolCall) => {
497
- if (toolCall.function && toolCall.id) {
498
- return {
499
- toolCallType: "function",
500
- toolCallId: toolCall.id,
501
- toolName: toolCall.function.name,
502
- args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments || {})
503
- };
504
- }
505
- return {
506
- toolCallType: "function",
507
- toolCallId: toolCall.name,
508
- toolName: toolCall.name,
509
- args: typeof toolCall.arguments === "string" ? toolCall.arguments : JSON.stringify(toolCall.arguments || {})
510
- };
511
- });
512
- }
513
- return [];
514
- }
515
- function prepareToolsAndToolChoice(mode) {
516
- const tools = mode.tools?.length ? mode.tools : void 0;
517
- if (tools == null) {
518
- return { tools: void 0, tool_choice: void 0 };
519
- }
520
- const mappedTools = tools.map((tool) => ({
521
- type: "function",
522
- function: {
523
- name: tool.name,
524
- // @ts-expect-error - description is not a property of tool
525
- description: tool.description,
526
- // @ts-expect-error - parameters is not a property of tool
527
- parameters: tool.parameters
528
- }
529
- }));
530
- const toolChoice = mode.toolChoice;
531
- if (toolChoice == null) {
532
- return { tools: mappedTools, tool_choice: void 0 };
533
- }
534
- const type = toolChoice.type;
535
- switch (type) {
536
- case "auto":
537
- return { tools: mappedTools, tool_choice: type };
538
- case "none":
539
- return { tools: mappedTools, tool_choice: type };
540
- case "required":
541
- return { tools: mappedTools, tool_choice: "any" };
542
- // workersAI does not support tool mode directly,
543
- // so we filter the tools and force the tool choice through 'any'
544
- case "tool":
545
- return {
546
- tools: mappedTools.filter(
547
- (tool) => tool.function.name === toolChoice.toolName
548
- ),
549
- tool_choice: "any"
550
- };
551
- default: {
552
- const exhaustiveCheck = type;
553
- throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
554
- }
555
- }
556
- }
557
- function lastMessageWasUser(messages) {
558
- return messages.length > 0 && messages[messages.length - 1].role === "user";
559
- }
560
686
 
561
687
  // src/workersai-image-model.ts
562
688
  var WorkersAIImageModel = class {
@@ -682,7 +808,23 @@ function createWorkersAI(options) {
682
808
  provider.imageModel = createImageModel;
683
809
  return provider;
684
810
  }
811
+ function createAutoRAG(options) {
812
+ const binding = options.binding;
813
+ const createChatModel = (settings = {}) => new AutoRAGChatLanguageModel("@cf/meta/llama-3.3-70b-instruct-fp8-fast", settings, {
814
+ provider: "autorag.chat",
815
+ binding
816
+ });
817
+ const provider = (settings) => {
818
+ if (new.target) {
819
+ throw new Error("The WorkersAI model function cannot be called with the new keyword.");
820
+ }
821
+ return createChatModel(settings);
822
+ };
823
+ provider.chat = createChatModel;
824
+ return provider;
825
+ }
685
826
  export {
827
+ createAutoRAG,
686
828
  createWorkersAI
687
829
  };
688
830
  //# sourceMappingURL=index.js.map