workers-ai-provider 0.6.3 → 0.7.0

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -90,7 +90,6 @@ type WorkersAIChatSettings = {
90
90
  safePrompt?: boolean;
91
91
  /**
92
92
  * Optionally set Cloudflare AI Gateway options.
93
- * @deprecated
94
93
  */
95
94
  gateway?: GatewayOptions;
96
95
  } & {
package/dist/index.js CHANGED
@@ -21,12 +21,11 @@ function convertToWorkersAIChatMessages(prompt) {
21
21
  for (const { role, content } of prompt) {
22
22
  switch (role) {
23
23
  case "system": {
24
- messages.push({ role: "system", content });
24
+ messages.push({ content, role: "system" });
25
25
  break;
26
26
  }
27
27
  case "user": {
28
28
  messages.push({
29
- role: "user",
30
29
  content: content.map((part) => {
31
30
  switch (part.type) {
32
31
  case "text": {
@@ -35,15 +34,16 @@ function convertToWorkersAIChatMessages(prompt) {
35
34
  case "image": {
36
35
  if (part.image instanceof Uint8Array) {
37
36
  images.push({
38
- mimeType: part.mimeType,
39
37
  image: part.image,
38
+ mimeType: part.mimeType,
40
39
  providerMetadata: part.providerMetadata
41
40
  });
42
41
  }
43
42
  return "";
44
43
  }
45
44
  }
46
- }).join("\n")
45
+ }).join("\n"),
46
+ role: "user"
47
47
  });
48
48
  break;
49
49
  }
@@ -56,34 +56,38 @@ function convertToWorkersAIChatMessages(prompt) {
56
56
  text += part.text;
57
57
  break;
58
58
  }
59
+ case "reasoning": {
60
+ text += part.text;
61
+ break;
62
+ }
59
63
  case "tool-call": {
60
64
  text = JSON.stringify({
61
65
  name: part.toolName,
62
66
  parameters: part.args
63
67
  });
64
68
  toolCalls.push({
65
- id: part.toolCallId,
66
- type: "function",
67
69
  function: {
68
- name: part.toolName,
69
- arguments: JSON.stringify(part.args)
70
- }
70
+ arguments: JSON.stringify(part.args),
71
+ name: part.toolName
72
+ },
73
+ id: part.toolCallId,
74
+ type: "function"
71
75
  });
72
76
  break;
73
77
  }
74
78
  default: {
75
79
  const exhaustiveCheck = part;
76
- throw new Error(`Unsupported part: ${exhaustiveCheck}`);
80
+ throw new Error(`Unsupported part type: ${exhaustiveCheck.type}`);
77
81
  }
78
82
  }
79
83
  }
80
84
  messages.push({
81
- role: "assistant",
82
85
  content: text,
86
+ role: "assistant",
83
87
  tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name, arguments: args } }) => ({
88
+ function: { arguments: args, name },
84
89
  id: "null",
85
- type: "function",
86
- function: { name, arguments: args }
90
+ type: "function"
87
91
  })) : void 0
88
92
  });
89
93
  break;
@@ -91,9 +95,9 @@ function convertToWorkersAIChatMessages(prompt) {
91
95
  case "tool": {
92
96
  for (const toolResponse of content) {
93
97
  messages.push({
94
- role: "tool",
98
+ content: JSON.stringify(toolResponse.result),
95
99
  name: toolResponse.toolName,
96
- content: JSON.stringify(toolResponse.result)
100
+ role: "tool"
97
101
  });
98
102
  }
99
103
  break;
@@ -104,18 +108,18 @@ function convertToWorkersAIChatMessages(prompt) {
104
108
  }
105
109
  }
106
110
  }
107
- return { messages, images };
111
+ return { images, messages };
108
112
  }
109
113
 
110
114
  // src/map-workersai-usage.ts
111
115
  function mapWorkersAIUsage(output) {
112
116
  const usage = output.usage ?? {
113
- prompt_tokens: 0,
114
- completion_tokens: 0
117
+ completion_tokens: 0,
118
+ prompt_tokens: 0
115
119
  };
116
120
  return {
117
- promptTokens: usage.prompt_tokens,
118
- completionTokens: usage.completion_tokens
121
+ completionTokens: usage.completion_tokens,
122
+ promptTokens: usage.prompt_tokens
119
123
  };
120
124
  }
121
125
 
@@ -225,7 +229,7 @@ function createRun(config) {
225
229
  continue;
226
230
  }
227
231
  urlParams.append(key, valueStr);
228
- } catch (error) {
232
+ } catch (_error) {
229
233
  throw new Error(
230
234
  `Value for option '${key}' is not able to be coerced into a string.`
231
235
  );
@@ -233,14 +237,14 @@ function createRun(config) {
233
237
  }
234
238
  const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
235
239
  const headers = {
236
- "Content-Type": "application/json",
237
- Authorization: `Bearer ${apiKey}`
240
+ Authorization: `Bearer ${apiKey}`,
241
+ "Content-Type": "application/json"
238
242
  };
239
243
  const body = JSON.stringify(inputs);
240
244
  const response = await fetch(url, {
241
- method: "POST",
245
+ body,
242
246
  headers,
243
- body
247
+ method: "POST"
244
248
  });
245
249
  if (returnRawResponse) {
246
250
  return response;
@@ -258,36 +262,36 @@ function createRun(config) {
258
262
  function prepareToolsAndToolChoice(mode) {
259
263
  const tools = mode.tools?.length ? mode.tools : void 0;
260
264
  if (tools == null) {
261
- return { tools: void 0, tool_choice: void 0 };
265
+ return { tool_choice: void 0, tools: void 0 };
262
266
  }
263
267
  const mappedTools = tools.map((tool) => ({
264
- type: "function",
265
268
  function: {
266
- name: tool.name,
267
269
  // @ts-expect-error - description is not a property of tool
268
270
  description: tool.description,
271
+ name: tool.name,
269
272
  // @ts-expect-error - parameters is not a property of tool
270
273
  parameters: tool.parameters
271
- }
274
+ },
275
+ type: "function"
272
276
  }));
273
277
  const toolChoice = mode.toolChoice;
274
278
  if (toolChoice == null) {
275
- return { tools: mappedTools, tool_choice: void 0 };
279
+ return { tool_choice: void 0, tools: mappedTools };
276
280
  }
277
281
  const type = toolChoice.type;
278
282
  switch (type) {
279
283
  case "auto":
280
- return { tools: mappedTools, tool_choice: type };
284
+ return { tool_choice: type, tools: mappedTools };
281
285
  case "none":
282
- return { tools: mappedTools, tool_choice: type };
286
+ return { tool_choice: type, tools: mappedTools };
283
287
  case "required":
284
- return { tools: mappedTools, tool_choice: "any" };
288
+ return { tool_choice: "any", tools: mappedTools };
285
289
  // workersAI does not support tool mode directly,
286
290
  // so we filter the tools and force the tool choice through 'any'
287
291
  case "tool":
288
292
  return {
289
- tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
290
- tool_choice: "any"
293
+ tool_choice: "any",
294
+ tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName)
291
295
  };
292
296
  default: {
293
297
  const exhaustiveCheck = type;
@@ -304,12 +308,12 @@ function mergePartialToolCalls(partialCalls) {
304
308
  const index = partialCall.index;
305
309
  if (!mergedCallsByIndex[index]) {
306
310
  mergedCallsByIndex[index] = {
307
- id: partialCall.id || "",
308
- type: partialCall.type || "",
309
311
  function: {
310
- name: partialCall.function?.name || "",
311
- arguments: ""
312
- }
312
+ arguments: "",
313
+ name: partialCall.function?.name || ""
314
+ },
315
+ id: partialCall.id || "",
316
+ type: partialCall.type || ""
313
317
  };
314
318
  } else {
315
319
  if (partialCall.id) {
@@ -331,17 +335,17 @@ function mergePartialToolCalls(partialCalls) {
331
335
  function processToolCall(toolCall) {
332
336
  if (toolCall.function && toolCall.id) {
333
337
  return {
334
- toolCallType: "function",
338
+ args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments || {}),
335
339
  toolCallId: toolCall.id,
336
- toolName: toolCall.function.name,
337
- args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments || {})
340
+ toolCallType: "function",
341
+ toolName: toolCall.function.name
338
342
  };
339
343
  }
340
344
  return {
341
- toolCallType: "function",
345
+ args: typeof toolCall.arguments === "string" ? toolCall.arguments : JSON.stringify(toolCall.arguments || {}),
342
346
  toolCallId: toolCall.name,
343
- toolName: toolCall.name,
344
- args: typeof toolCall.arguments === "string" ? toolCall.arguments : JSON.stringify(toolCall.arguments || {})
347
+ toolCallType: "function",
348
+ toolName: toolCall.name
345
349
  };
346
350
  }
347
351
  function processToolCalls(output) {
@@ -351,6 +355,12 @@ function processToolCalls(output) {
351
355
  return processedToolCall;
352
356
  });
353
357
  }
358
+ if (output?.choices?.[0]?.message?.tool_calls && Array.isArray(output.choices[0].message.tool_calls)) {
359
+ return output.choices[0].message.tool_calls.map((toolCall) => {
360
+ const processedToolCall = processToolCall(toolCall);
361
+ return processedToolCall;
362
+ });
363
+ }
354
364
  return [];
355
365
  }
356
366
  function processPartialToolCalls(partialToolCalls) {
@@ -361,7 +371,7 @@ function processPartialToolCalls(partialToolCalls) {
361
371
  // src/streaming.ts
362
372
  function getMappedStream(response) {
363
373
  const chunkEvent = events(response);
364
- let usage = { promptTokens: 0, completionTokens: 0 };
374
+ let usage = { completionTokens: 0, promptTokens: 0 };
365
375
  const partialToolCalls = [];
366
376
  return new ReadableStream({
367
377
  async start(controller) {
@@ -380,8 +390,12 @@ function getMappedStream(response) {
380
390
  partialToolCalls.push(...chunk.tool_calls);
381
391
  }
382
392
  chunk.response?.length && controller.enqueue({
383
- type: "text-delta",
384
- textDelta: chunk.response
393
+ textDelta: chunk.response,
394
+ type: "text-delta"
395
+ });
396
+ chunk?.choices?.[0]?.delta?.reasoning_content?.length && controller.enqueue({
397
+ type: "reasoning",
398
+ textDelta: chunk.choices[0].delta.reasoning_content
385
399
  });
386
400
  }
387
401
  if (partialToolCalls.length > 0) {
@@ -394,8 +408,8 @@ function getMappedStream(response) {
394
408
  });
395
409
  }
396
410
  controller.enqueue({
397
- type: "finish",
398
411
  finishReason: "stop",
412
+ type: "finish",
399
413
  usage
400
414
  });
401
415
  controller.close();
@@ -428,21 +442,21 @@ var AutoRAGChatLanguageModel = class {
428
442
  const warnings = [];
429
443
  if (frequencyPenalty != null) {
430
444
  warnings.push({
431
- type: "unsupported-setting",
432
- setting: "frequencyPenalty"
445
+ setting: "frequencyPenalty",
446
+ type: "unsupported-setting"
433
447
  });
434
448
  }
435
449
  if (presencePenalty != null) {
436
450
  warnings.push({
437
- type: "unsupported-setting",
438
- setting: "presencePenalty"
451
+ setting: "presencePenalty",
452
+ type: "unsupported-setting"
439
453
  });
440
454
  }
441
455
  const baseArgs = {
442
- // model id:
443
- model: this.modelId,
444
456
  // messages:
445
- messages: convertToWorkersAIChatMessages(prompt)
457
+ messages: convertToWorkersAIChatMessages(prompt),
458
+ // model id:
459
+ model: this.modelId
446
460
  };
447
461
  switch (type) {
448
462
  case "regular": {
@@ -456,8 +470,8 @@ var AutoRAGChatLanguageModel = class {
456
470
  args: {
457
471
  ...baseArgs,
458
472
  response_format: {
459
- type: "json_schema",
460
- json_schema: mode.schema
473
+ json_schema: mode.schema,
474
+ type: "json_schema"
461
475
  },
462
476
  tools: void 0
463
477
  },
@@ -469,7 +483,7 @@ var AutoRAGChatLanguageModel = class {
469
483
  args: {
470
484
  ...baseArgs,
471
485
  tool_choice: "any",
472
- tools: [{ type: "function", function: mode.tool }]
486
+ tools: [{ function: mode.tool, type: "function" }]
473
487
  },
474
488
  warnings
475
489
  };
@@ -494,21 +508,21 @@ var AutoRAGChatLanguageModel = class {
494
508
  query: messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n")
495
509
  });
496
510
  return {
497
- text: output.response,
498
- toolCalls: processToolCalls(output),
499
511
  finishReason: "stop",
500
- // TODO: mapWorkersAIFinishReason(response.finish_reason),
501
512
  rawCall: { rawPrompt: args.messages, rawSettings: args },
502
- usage: mapWorkersAIUsage(output),
503
- warnings,
504
513
  sources: output.data.map(({ file_id, filename, score }) => ({
505
514
  id: file_id,
506
- sourceType: "url",
507
- url: filename,
508
515
  providerMetadata: {
509
516
  attributes: { score }
510
- }
511
- }))
517
+ },
518
+ sourceType: "url",
519
+ url: filename
520
+ })),
521
+ // TODO: mapWorkersAIFinishReason(response.finish_reason),
522
+ text: output.response,
523
+ toolCalls: processToolCalls(output),
524
+ usage: mapWorkersAIUsage(output),
525
+ warnings
512
526
  };
513
527
  }
514
528
  async doStream(options) {
@@ -520,8 +534,8 @@ var AutoRAGChatLanguageModel = class {
520
534
  stream: true
521
535
  });
522
536
  return {
523
- stream: getMappedStream(response),
524
537
  rawCall: { rawPrompt: args.messages, rawSettings: args },
538
+ stream: getMappedStream(response),
525
539
  warnings
526
540
  };
527
541
  }
@@ -561,9 +575,9 @@ var WorkersAIEmbeddingModel = class {
561
575
  }) {
562
576
  if (values.length > this.maxEmbeddingsPerCall) {
563
577
  throw new TooManyEmbeddingValuesForCallError({
564
- provider: this.provider,
565
- modelId: this.modelId,
566
578
  maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
579
+ modelId: this.modelId,
580
+ provider: this.provider,
567
581
  values
568
582
  });
569
583
  }
@@ -650,26 +664,26 @@ var WorkersAIChatLanguageModel = class {
650
664
  const warnings = [];
651
665
  if (frequencyPenalty != null) {
652
666
  warnings.push({
653
- type: "unsupported-setting",
654
- setting: "frequencyPenalty"
667
+ setting: "frequencyPenalty",
668
+ type: "unsupported-setting"
655
669
  });
656
670
  }
657
671
  if (presencePenalty != null) {
658
672
  warnings.push({
659
- type: "unsupported-setting",
660
- setting: "presencePenalty"
673
+ setting: "presencePenalty",
674
+ type: "unsupported-setting"
661
675
  });
662
676
  }
663
677
  const baseArgs = {
678
+ // standardized settings:
679
+ max_tokens: maxTokens,
664
680
  // model id:
665
681
  model: this.modelId,
682
+ random_seed: seed,
666
683
  // model specific settings:
667
684
  safe_prompt: this.settings.safePrompt,
668
- // standardized settings:
669
- max_tokens: maxTokens,
670
685
  temperature,
671
- top_p: topP,
672
- random_seed: seed
686
+ top_p: topP
673
687
  };
674
688
  switch (type) {
675
689
  case "regular": {
@@ -683,8 +697,8 @@ var WorkersAIChatLanguageModel = class {
683
697
  args: {
684
698
  ...baseArgs,
685
699
  response_format: {
686
- type: "json_schema",
687
- json_schema: mode.schema
700
+ json_schema: mode.schema,
701
+ type: "json_schema"
688
702
  },
689
703
  tools: void 0
690
704
  },
@@ -696,7 +710,7 @@ var WorkersAIChatLanguageModel = class {
696
710
  args: {
697
711
  ...baseArgs,
698
712
  tool_choice: "any",
699
- tools: [{ type: "function", function: mode.tool }]
713
+ tools: [{ function: mode.tool, type: "function" }]
700
714
  },
701
715
  warnings
702
716
  };
@@ -725,8 +739,8 @@ var WorkersAIChatLanguageModel = class {
725
739
  const output = await this.config.binding.run(
726
740
  args.model,
727
741
  {
728
- messages,
729
742
  max_tokens: args.max_tokens,
743
+ messages,
730
744
  temperature: args.temperature,
731
745
  tools: args.tools,
732
746
  top_p: args.top_p,
@@ -742,11 +756,13 @@ var WorkersAIChatLanguageModel = class {
742
756
  throw new Error("This shouldn't happen");
743
757
  }
744
758
  return {
745
- text: typeof output.response === "object" && output.response !== null ? JSON.stringify(output.response) : output.response,
746
- toolCalls: processToolCalls(output),
747
759
  finishReason: mapWorkersAIFinishReason(output),
748
760
  rawCall: { rawPrompt: messages, rawSettings: args },
749
761
  rawResponse: { body: output },
762
+ text: typeof output.response === "object" && output.response !== null ? JSON.stringify(output.response) : output.response,
763
+ toolCalls: processToolCalls(output),
764
+ // @ts-ignore: Missing types
765
+ reasoning: output?.choices?.[0]?.message?.reasoning_content,
750
766
  usage: mapWorkersAIUsage(output),
751
767
  warnings
752
768
  };
@@ -760,12 +776,13 @@ var WorkersAIChatLanguageModel = class {
760
776
  throw new Error("This shouldn't happen");
761
777
  }
762
778
  return {
779
+ rawCall: { rawPrompt: messages, rawSettings: args },
763
780
  stream: new ReadableStream({
764
781
  async start(controller) {
765
782
  if (response2.text) {
766
783
  controller.enqueue({
767
- type: "text-delta",
768
- textDelta: response2.text
784
+ textDelta: response2.text,
785
+ type: "text-delta"
769
786
  });
770
787
  }
771
788
  if (response2.toolCalls) {
@@ -776,15 +793,20 @@ var WorkersAIChatLanguageModel = class {
776
793
  });
777
794
  }
778
795
  }
796
+ if (response2.reasoning && typeof response2.reasoning === "string") {
797
+ controller.enqueue({
798
+ type: "reasoning",
799
+ textDelta: response2.reasoning
800
+ });
801
+ }
779
802
  controller.enqueue({
780
- type: "finish",
781
803
  finishReason: mapWorkersAIFinishReason(response2),
804
+ type: "finish",
782
805
  usage: response2.usage
783
806
  });
784
807
  controller.close();
785
808
  }
786
809
  }),
787
- rawCall: { rawPrompt: messages, rawSettings: args },
788
810
  warnings
789
811
  };
790
812
  }
@@ -796,8 +818,8 @@ var WorkersAIChatLanguageModel = class {
796
818
  const response = await this.config.binding.run(
797
819
  args.model,
798
820
  {
799
- messages,
800
821
  max_tokens: args.max_tokens,
822
+ messages,
801
823
  stream: true,
802
824
  temperature: args.temperature,
803
825
  tools: args.tools,
@@ -814,8 +836,8 @@ var WorkersAIChatLanguageModel = class {
814
836
  throw new Error("This shouldn't happen");
815
837
  }
816
838
  return {
817
- stream: getMappedStream(new Response(response)),
818
839
  rawCall: { rawPrompt: messages, rawSettings: args },
840
+ stream: getMappedStream(new Response(response)),
819
841
  warnings
820
842
  };
821
843
  }
@@ -848,19 +870,19 @@ var WorkersAIImageModel = class {
848
870
  const warnings = [];
849
871
  if (aspectRatio != null) {
850
872
  warnings.push({
851
- type: "unsupported-setting",
873
+ details: "This model does not support aspect ratio. Use `size` instead.",
852
874
  setting: "aspectRatio",
853
- details: "This model does not support aspect ratio. Use `size` instead."
875
+ type: "unsupported-setting"
854
876
  });
855
877
  }
856
878
  const generateImage = async () => {
857
879
  const outputStream = await this.config.binding.run(
858
880
  this.modelId,
859
881
  {
882
+ height,
860
883
  prompt,
861
884
  seed,
862
- width,
863
- height
885
+ width
864
886
  }
865
887
  );
866
888
  return streamToUint8Array(outputStream);
@@ -870,20 +892,20 @@ var WorkersAIImageModel = class {
870
892
  );
871
893
  return {
872
894
  images,
873
- warnings,
874
895
  response: {
875
- timestamp: /* @__PURE__ */ new Date(),
896
+ headers: {},
876
897
  modelId: this.modelId,
877
- headers: {}
878
- }
898
+ timestamp: /* @__PURE__ */ new Date()
899
+ },
900
+ warnings
879
901
  };
880
902
  }
881
903
  };
882
904
  function getDimensionsFromSizeString(size) {
883
905
  const [width, height] = size?.split("x") ?? [void 0, void 0];
884
906
  return {
885
- width: parseInteger(width),
886
- height: parseInteger(height)
907
+ height: parseInteger(height),
908
+ width: parseInteger(width)
887
909
  };
888
910
  }
889
911
  function parseInteger(value) {
@@ -925,19 +947,19 @@ function createWorkersAI(options) {
925
947
  throw new Error("Either a binding or credentials must be provided.");
926
948
  }
927
949
  const createChatModel = (modelId, settings = {}) => new WorkersAIChatLanguageModel(modelId, settings, {
928
- provider: "workersai.chat",
929
950
  binding,
930
- gateway: options.gateway
951
+ gateway: options.gateway,
952
+ provider: "workersai.chat"
931
953
  });
932
954
  const createImageModel = (modelId, settings = {}) => new WorkersAIImageModel(modelId, settings, {
933
- provider: "workersai.image",
934
955
  binding,
935
- gateway: options.gateway
956
+ gateway: options.gateway,
957
+ provider: "workersai.image"
936
958
  });
937
959
  const createEmbeddingModel = (modelId, settings = {}) => new WorkersAIEmbeddingModel(modelId, settings, {
938
- provider: "workersai.embedding",
939
960
  binding,
940
- gateway: options.gateway
961
+ gateway: options.gateway,
962
+ provider: "workersai.embedding"
941
963
  });
942
964
  const provider = (modelId, settings) => {
943
965
  if (new.target) {
@@ -958,8 +980,8 @@ function createAutoRAG(options) {
958
980
  const createChatModel = (settings = {}) => (
959
981
  // @ts-ignore Needs fix from @cloudflare/workers-types for custom types
960
982
  new AutoRAGChatLanguageModel("@cf/meta/llama-3.3-70b-instruct-fp8-fast", settings, {
961
- provider: "autorag.chat",
962
- binding
983
+ binding,
984
+ provider: "autorag.chat"
963
985
  })
964
986
  );
965
987
  const provider = (settings) => {