workers-ai-provider 0.3.2 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -4,17 +4,12 @@ type StringLike = string | {
4
4
  toString(): string;
5
5
  };
6
6
 
7
- type WorkersAIChatSettings = {
7
+ type AutoRAGChatSettings = {
8
8
  /**
9
9
  * Whether to inject a safety prompt before all conversations.
10
10
  * Defaults to `false`.
11
11
  */
12
12
  safePrompt?: boolean;
13
- /**
14
- * Optionally set Cloudflare AI Gateway options.
15
- * @deprecated
16
- */
17
- gateway?: GatewayOptions;
18
13
  } & {
19
14
  /**
20
15
  * Passthrough settings that are provided directly to the run function.
@@ -31,6 +26,42 @@ type value2key<T, V> = {
31
26
  [K in keyof T]: T[K] extends V ? K : never;
32
27
  }[keyof T];
33
28
 
29
+ type AutoRAGChatConfig = {
30
+ provider: string;
31
+ binding: AutoRAG;
32
+ gateway?: GatewayOptions;
33
+ };
34
+ declare class AutoRAGChatLanguageModel implements LanguageModelV1 {
35
+ readonly specificationVersion = "v1";
36
+ readonly defaultObjectGenerationMode = "json";
37
+ readonly modelId: TextGenerationModels;
38
+ readonly settings: AutoRAGChatSettings;
39
+ private readonly config;
40
+ constructor(modelId: TextGenerationModels, settings: AutoRAGChatSettings, config: AutoRAGChatConfig);
41
+ get provider(): string;
42
+ private getArgs;
43
+ doGenerate(options: Parameters<LanguageModelV1["doGenerate"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>>;
44
+ doStream(options: Parameters<LanguageModelV1["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>>;
45
+ }
46
+
47
+ type WorkersAIChatSettings = {
48
+ /**
49
+ * Whether to inject a safety prompt before all conversations.
50
+ * Defaults to `false`.
51
+ */
52
+ safePrompt?: boolean;
53
+ /**
54
+ * Optionally set Cloudflare AI Gateway options.
55
+ * @deprecated
56
+ */
57
+ gateway?: GatewayOptions;
58
+ } & {
59
+ /**
60
+ * Passthrough settings that are provided directly to the run function.
61
+ */
62
+ [key: string]: StringLike;
63
+ };
64
+
34
65
  type WorkersAIChatConfig = {
35
66
  provider: string;
36
67
  binding: Ai;
@@ -111,5 +142,19 @@ interface WorkersAI {
111
142
  * Create a Workers AI provider instance.
112
143
  */
113
144
  declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
145
+ type AutoRAGSettings = {
146
+ binding: AutoRAG;
147
+ };
148
+ interface AutoRAGProvider {
149
+ (options?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
150
+ /**
151
+ * Creates a model for text generation.
152
+ **/
153
+ chat(settings?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
154
+ }
155
+ /**
156
+ * Create a Workers AI provider instance.
157
+ */
158
+ declare function createAutoRAG(options: AutoRAGSettings): AutoRAGProvider;
114
159
 
115
- export { type WorkersAI, type WorkersAISettings, createWorkersAI };
160
+ export { type AutoRAGProvider, type AutoRAGSettings, type WorkersAI, type WorkersAISettings, createAutoRAG, createWorkersAI };
package/dist/index.js CHANGED
@@ -9,51 +9,7 @@ var __privateGet = (obj, member, getter) => (__accessCheck(obj, member, "read fr
9
9
  var __privateAdd = (obj, member, value) => member.has(obj) ? __typeError("Cannot add the same private member more than once") : member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
10
10
  var __privateSet = (obj, member, value, setter) => (__accessCheck(obj, member, "write to private field"), setter ? setter.call(obj, value) : member.set(obj, value), value);
11
11
 
12
- // src/utils.ts
13
- function createRun(config) {
14
- const { accountId, apiKey } = config;
15
- return async function run(model, inputs, options) {
16
- const { gateway, prefix, extraHeaders, returnRawResponse, ...passthroughOptions } = options || {};
17
- const urlParams = new URLSearchParams();
18
- for (const [key, value] of Object.entries(passthroughOptions)) {
19
- try {
20
- const valueStr = value.toString();
21
- if (!valueStr) {
22
- continue;
23
- }
24
- urlParams.append(key, valueStr);
25
- } catch (error) {
26
- throw new Error(
27
- `Value for option '${key}' is not able to be coerced into a string.`
28
- );
29
- }
30
- }
31
- const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
32
- const headers = {
33
- "Content-Type": "application/json",
34
- Authorization: `Bearer ${apiKey}`
35
- };
36
- const body = JSON.stringify(inputs);
37
- const response = await fetch(url, {
38
- method: "POST",
39
- headers,
40
- body
41
- });
42
- if (returnRawResponse) {
43
- return response;
44
- }
45
- if (inputs.stream === true) {
46
- if (response.body) {
47
- return response.body;
48
- }
49
- throw new Error("No readable body available for streaming.");
50
- }
51
- const data = await response.json();
52
- return data.result;
53
- };
54
- }
55
-
56
- // src/workersai-chat-language-model.ts
12
+ // src/autorag-chat-language-model.ts
57
13
  import {
58
14
  UnsupportedFunctionalityError
59
15
  } from "@ai-sdk/provider";
@@ -157,6 +113,18 @@ function convertToWorkersAIChatMessages(prompt) {
157
113
  return { messages, images };
158
114
  }
159
115
 
116
+ // src/map-workersai-usage.ts
117
+ function mapWorkersAIUsage(output) {
118
+ const usage = output.usage ?? {
119
+ prompt_tokens: 0,
120
+ completion_tokens: 0
121
+ };
122
+ return {
123
+ promptTokens: usage.prompt_tokens,
124
+ completionTokens: usage.completion_tokens
125
+ };
126
+ }
127
+
160
128
  // ../../node_modules/.pnpm/fetch-event-stream@0.1.5/node_modules/fetch-event-stream/esm/deps/jsr.io/@std/streams/0.221.0/text_line_stream.js
161
129
  var _currentLine;
162
130
  var TextLineStream = class extends TransformStream {
@@ -250,19 +218,274 @@ async function* events(res, signal) {
250
218
  }
251
219
  }
252
220
 
253
- // src/map-workersai-usage.ts
254
- function mapWorkersAIUsage(output) {
255
- const usage = output.usage ?? {
256
- prompt_tokens: 0,
257
- completion_tokens: 0
258
- };
259
- return {
260
- promptTokens: usage.prompt_tokens,
261
- completionTokens: usage.completion_tokens
221
+ // src/streaming.ts
222
+ function getMappedStream(response) {
223
+ const chunkEvent = events(response);
224
+ let usage = { promptTokens: 0, completionTokens: 0 };
225
+ return new ReadableStream({
226
+ async start(controller) {
227
+ for await (const event of chunkEvent) {
228
+ if (!event.data) {
229
+ continue;
230
+ }
231
+ if (event.data === "[DONE]") {
232
+ break;
233
+ }
234
+ const chunk = JSON.parse(event.data);
235
+ if (chunk.usage) {
236
+ usage = mapWorkersAIUsage(chunk);
237
+ }
238
+ chunk.response?.length && controller.enqueue({
239
+ type: "text-delta",
240
+ textDelta: chunk.response
241
+ });
242
+ }
243
+ controller.enqueue({
244
+ type: "finish",
245
+ finishReason: "stop",
246
+ usage
247
+ });
248
+ controller.close();
249
+ }
250
+ });
251
+ }
252
+
253
+ // src/utils.ts
254
+ function createRun(config) {
255
+ const { accountId, apiKey } = config;
256
+ return async function run(model, inputs, options) {
257
+ const { gateway, prefix, extraHeaders, returnRawResponse, ...passthroughOptions } = options || {};
258
+ const urlParams = new URLSearchParams();
259
+ for (const [key, value] of Object.entries(passthroughOptions)) {
260
+ try {
261
+ const valueStr = value.toString();
262
+ if (!valueStr) {
263
+ continue;
264
+ }
265
+ urlParams.append(key, valueStr);
266
+ } catch (error) {
267
+ throw new Error(
268
+ `Value for option '${key}' is not able to be coerced into a string.`
269
+ );
270
+ }
271
+ }
272
+ const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
273
+ const headers = {
274
+ "Content-Type": "application/json",
275
+ Authorization: `Bearer ${apiKey}`
276
+ };
277
+ const body = JSON.stringify(inputs);
278
+ const response = await fetch(url, {
279
+ method: "POST",
280
+ headers,
281
+ body
282
+ });
283
+ if (returnRawResponse) {
284
+ return response;
285
+ }
286
+ if (inputs.stream === true) {
287
+ if (response.body) {
288
+ return response.body;
289
+ }
290
+ throw new Error("No readable body available for streaming.");
291
+ }
292
+ const data = await response.json();
293
+ return data.result;
262
294
  };
263
295
  }
296
+ function prepareToolsAndToolChoice(mode) {
297
+ const tools = mode.tools?.length ? mode.tools : void 0;
298
+ if (tools == null) {
299
+ return { tools: void 0, tool_choice: void 0 };
300
+ }
301
+ const mappedTools = tools.map((tool) => ({
302
+ type: "function",
303
+ function: {
304
+ name: tool.name,
305
+ // @ts-expect-error - description is not a property of tool
306
+ description: tool.description,
307
+ // @ts-expect-error - parameters is not a property of tool
308
+ parameters: tool.parameters
309
+ }
310
+ }));
311
+ const toolChoice = mode.toolChoice;
312
+ if (toolChoice == null) {
313
+ return { tools: mappedTools, tool_choice: void 0 };
314
+ }
315
+ const type = toolChoice.type;
316
+ switch (type) {
317
+ case "auto":
318
+ return { tools: mappedTools, tool_choice: type };
319
+ case "none":
320
+ return { tools: mappedTools, tool_choice: type };
321
+ case "required":
322
+ return { tools: mappedTools, tool_choice: "any" };
323
+ // workersAI does not support tool mode directly,
324
+ // so we filter the tools and force the tool choice through 'any'
325
+ case "tool":
326
+ return {
327
+ tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
328
+ tool_choice: "any"
329
+ };
330
+ default: {
331
+ const exhaustiveCheck = type;
332
+ throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
333
+ }
334
+ }
335
+ }
336
+ function lastMessageWasUser(messages) {
337
+ return messages.length > 0 && messages[messages.length - 1].role === "user";
338
+ }
339
+ function processToolCalls(output) {
340
+ if (output.tool_calls && Array.isArray(output.tool_calls)) {
341
+ return output.tool_calls.map((toolCall) => {
342
+ if (toolCall.function && toolCall.id) {
343
+ return {
344
+ toolCallType: "function",
345
+ toolCallId: toolCall.id,
346
+ toolName: toolCall.function.name,
347
+ args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments || {})
348
+ };
349
+ }
350
+ return {
351
+ toolCallType: "function",
352
+ toolCallId: toolCall.name,
353
+ toolName: toolCall.name,
354
+ args: typeof toolCall.arguments === "string" ? toolCall.arguments : JSON.stringify(toolCall.arguments || {})
355
+ };
356
+ });
357
+ }
358
+ return [];
359
+ }
360
+
361
+ // src/autorag-chat-language-model.ts
362
+ var AutoRAGChatLanguageModel = class {
363
+ constructor(modelId, settings, config) {
364
+ __publicField(this, "specificationVersion", "v1");
365
+ __publicField(this, "defaultObjectGenerationMode", "json");
366
+ __publicField(this, "modelId");
367
+ __publicField(this, "settings");
368
+ __publicField(this, "config");
369
+ this.modelId = modelId;
370
+ this.settings = settings;
371
+ this.config = config;
372
+ }
373
+ get provider() {
374
+ return this.config.provider;
375
+ }
376
+ getArgs({
377
+ mode,
378
+ prompt,
379
+ frequencyPenalty,
380
+ presencePenalty
381
+ }) {
382
+ const type = mode.type;
383
+ const warnings = [];
384
+ if (frequencyPenalty != null) {
385
+ warnings.push({
386
+ type: "unsupported-setting",
387
+ setting: "frequencyPenalty"
388
+ });
389
+ }
390
+ if (presencePenalty != null) {
391
+ warnings.push({
392
+ type: "unsupported-setting",
393
+ setting: "presencePenalty"
394
+ });
395
+ }
396
+ const baseArgs = {
397
+ // model id:
398
+ model: this.modelId,
399
+ // messages:
400
+ messages: convertToWorkersAIChatMessages(prompt)
401
+ };
402
+ switch (type) {
403
+ case "regular": {
404
+ return {
405
+ args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
406
+ warnings
407
+ };
408
+ }
409
+ case "object-json": {
410
+ return {
411
+ args: {
412
+ ...baseArgs,
413
+ response_format: {
414
+ type: "json_schema",
415
+ json_schema: mode.schema
416
+ },
417
+ tools: void 0
418
+ },
419
+ warnings
420
+ };
421
+ }
422
+ case "object-tool": {
423
+ return {
424
+ args: {
425
+ ...baseArgs,
426
+ tool_choice: "any",
427
+ tools: [{ type: "function", function: mode.tool }]
428
+ },
429
+ warnings
430
+ };
431
+ }
432
+ // @ts-expect-error - this is unreachable code
433
+ // TODO: fixme
434
+ case "object-grammar": {
435
+ throw new UnsupportedFunctionalityError({
436
+ functionality: "object-grammar mode"
437
+ });
438
+ }
439
+ default: {
440
+ const exhaustiveCheck = type;
441
+ throw new Error(`Unsupported type: ${exhaustiveCheck}`);
442
+ }
443
+ }
444
+ }
445
+ async doGenerate(options) {
446
+ const { args, warnings } = this.getArgs(options);
447
+ const { messages } = convertToWorkersAIChatMessages(options.prompt);
448
+ const output = await this.config.binding.aiSearch({
449
+ query: messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n")
450
+ });
451
+ return {
452
+ text: output.response,
453
+ toolCalls: processToolCalls(output),
454
+ finishReason: "stop",
455
+ // TODO: mapWorkersAIFinishReason(response.finish_reason),
456
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
457
+ usage: mapWorkersAIUsage(output),
458
+ warnings,
459
+ sources: output.data.map(({ file_id, filename, score }) => ({
460
+ id: file_id,
461
+ sourceType: "url",
462
+ url: filename,
463
+ providerMetadata: {
464
+ attributes: { score }
465
+ }
466
+ }))
467
+ };
468
+ }
469
+ async doStream(options) {
470
+ const { args, warnings } = this.getArgs(options);
471
+ const { messages } = convertToWorkersAIChatMessages(options.prompt);
472
+ const query = messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n");
473
+ const response = await this.config.binding.aiSearch({
474
+ query,
475
+ stream: true
476
+ });
477
+ return {
478
+ stream: getMappedStream(response),
479
+ rawCall: { rawPrompt: args.messages, rawSettings: args },
480
+ warnings
481
+ };
482
+ }
483
+ };
264
484
 
265
485
  // src/workersai-chat-language-model.ts
486
+ import {
487
+ UnsupportedFunctionalityError as UnsupportedFunctionalityError2
488
+ } from "@ai-sdk/provider";
266
489
  var WorkersAIChatLanguageModel = class {
267
490
  constructor(modelId, settings, config) {
268
491
  __publicField(this, "specificationVersion", "v1");
@@ -344,7 +567,7 @@ var WorkersAIChatLanguageModel = class {
344
567
  // @ts-expect-error - this is unreachable code
345
568
  // TODO: fixme
346
569
  case "object-grammar": {
347
- throw new UnsupportedFunctionalityError({
570
+ throw new UnsupportedFunctionalityError2({
348
571
  functionality: "object-grammar mode"
349
572
  });
350
573
  }
@@ -357,9 +580,7 @@ var WorkersAIChatLanguageModel = class {
357
580
  async doGenerate(options) {
358
581
  const { args, warnings } = this.getArgs(options);
359
582
  const { gateway, safePrompt, ...passthroughOptions } = this.settings;
360
- const { messages, images } = convertToWorkersAIChatMessages(
361
- options.prompt
362
- );
583
+ const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
363
584
  if (images.length !== 0 && images.length !== 1) {
364
585
  throw new Error("Multiple images are not yet supported as input");
365
586
  }
@@ -385,12 +606,7 @@ var WorkersAIChatLanguageModel = class {
385
606
  }
386
607
  return {
387
608
  text: typeof output.response === "object" && output.response !== null ? JSON.stringify(output.response) : output.response,
388
- toolCalls: output.tool_calls?.map((toolCall) => ({
389
- toolCallType: "function",
390
- toolCallId: toolCall.name,
391
- toolName: toolCall.name,
392
- args: JSON.stringify(toolCall.arguments || {})
393
- })),
609
+ toolCalls: processToolCalls(output),
394
610
  finishReason: "stop",
395
611
  // TODO: mapWorkersAIFinishReason(response.finish_reason),
396
612
  rawCall: { rawPrompt: messages, rawSettings: args },
@@ -400,9 +616,7 @@ var WorkersAIChatLanguageModel = class {
400
616
  }
401
617
  async doStream(options) {
402
618
  const { args, warnings } = this.getArgs(options);
403
- const { messages, images } = convertToWorkersAIChatMessages(
404
- options.prompt
405
- );
619
+ const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
406
620
  if (args.tools?.length && lastMessageWasUser(messages)) {
407
621
  const response2 = await this.doGenerate(options);
408
622
  if (response2 instanceof ReadableStream) {
@@ -462,85 +676,13 @@ var WorkersAIChatLanguageModel = class {
462
676
  if (!(response instanceof ReadableStream)) {
463
677
  throw new Error("This shouldn't happen");
464
678
  }
465
- const chunkEvent = events(new Response(response));
466
- let usage = { promptTokens: 0, completionTokens: 0 };
467
679
  return {
468
- stream: new ReadableStream({
469
- async start(controller) {
470
- for await (const event of chunkEvent) {
471
- if (!event.data) {
472
- continue;
473
- }
474
- if (event.data === "[DONE]") {
475
- break;
476
- }
477
- const chunk = JSON.parse(event.data);
478
- if (chunk.usage) {
479
- usage = mapWorkersAIUsage(chunk);
480
- }
481
- chunk.response?.length && controller.enqueue({
482
- type: "text-delta",
483
- textDelta: chunk.response
484
- });
485
- }
486
- controller.enqueue({
487
- type: "finish",
488
- finishReason: "stop",
489
- usage
490
- });
491
- controller.close();
492
- }
493
- }),
680
+ stream: getMappedStream(new Response(response)),
494
681
  rawCall: { rawPrompt: messages, rawSettings: args },
495
682
  warnings
496
683
  };
497
684
  }
498
685
  };
499
- function prepareToolsAndToolChoice(mode) {
500
- const tools = mode.tools?.length ? mode.tools : void 0;
501
- if (tools == null) {
502
- return { tools: void 0, tool_choice: void 0 };
503
- }
504
- const mappedTools = tools.map((tool) => ({
505
- type: "function",
506
- function: {
507
- name: tool.name,
508
- // @ts-expect-error - description is not a property of tool
509
- description: tool.description,
510
- // @ts-expect-error - parameters is not a property of tool
511
- parameters: tool.parameters
512
- }
513
- }));
514
- const toolChoice = mode.toolChoice;
515
- if (toolChoice == null) {
516
- return { tools: mappedTools, tool_choice: void 0 };
517
- }
518
- const type = toolChoice.type;
519
- switch (type) {
520
- case "auto":
521
- return { tools: mappedTools, tool_choice: type };
522
- case "none":
523
- return { tools: mappedTools, tool_choice: type };
524
- case "required":
525
- return { tools: mappedTools, tool_choice: "any" };
526
- // workersAI does not support tool mode directly,
527
- // so we filter the tools and force the tool choice through 'any'
528
- case "tool":
529
- return {
530
- tools: mappedTools.filter(
531
- (tool) => tool.function.name === toolChoice.toolName
532
- ),
533
- tool_choice: "any"
534
- };
535
- default: {
536
- const exhaustiveCheck = type;
537
- throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
538
- }
539
- }
540
- }
541
- function lastMessageWasUser(messages) {
542
- return messages.length > 0 && messages[messages.length - 1].role === "user";
543
- }
544
686
 
545
687
  // src/workersai-image-model.ts
546
688
  var WorkersAIImageModel = class {
@@ -666,7 +808,23 @@ function createWorkersAI(options) {
666
808
  provider.imageModel = createImageModel;
667
809
  return provider;
668
810
  }
811
+ function createAutoRAG(options) {
812
+ const binding = options.binding;
813
+ const createChatModel = (settings = {}) => new AutoRAGChatLanguageModel("@cf/meta/llama-3.3-70b-instruct-fp8-fast", settings, {
814
+ provider: "autorag.chat",
815
+ binding
816
+ });
817
+ const provider = (settings) => {
818
+ if (new.target) {
819
+ throw new Error("The WorkersAI model function cannot be called with the new keyword.");
820
+ }
821
+ return createChatModel(settings);
822
+ };
823
+ provider.chat = createChatModel;
824
+ return provider;
825
+ }
669
826
  export {
827
+ createAutoRAG,
670
828
  createWorkersAI
671
829
  };
672
830
  //# sourceMappingURL=index.js.map