@assistant-ui/react 0.5.34 → 0.5.36

Sign up to get free protection for your applications and to get access to all the features.
@@ -165,6 +165,74 @@ var toLanguageModelTools = (tools) => {
165
165
  }));
166
166
  };
167
167
 
168
+ // src/runtimes/edge/streams/assistantEncoderStream.ts
169
+ function assistantEncoderStream() {
170
+ const toolCalls = /* @__PURE__ */ new Set();
171
+ return new TransformStream({
172
+ transform(chunk, controller) {
173
+ const chunkType = chunk.type;
174
+ switch (chunkType) {
175
+ case "text-delta": {
176
+ controller.enqueue({
177
+ type: "0" /* TextDelta */,
178
+ value: chunk.textDelta
179
+ });
180
+ break;
181
+ }
182
+ case "tool-call-delta": {
183
+ if (!toolCalls.has(chunk.toolCallId)) {
184
+ toolCalls.add(chunk.toolCallId);
185
+ controller.enqueue({
186
+ type: "1" /* ToolCallBegin */,
187
+ value: {
188
+ id: chunk.toolCallId,
189
+ name: chunk.toolName
190
+ }
191
+ });
192
+ }
193
+ controller.enqueue({
194
+ type: "2" /* ToolCallArgsTextDelta */,
195
+ value: chunk.argsTextDelta
196
+ });
197
+ break;
198
+ }
199
+ // ignore
200
+ case "tool-call":
201
+ break;
202
+ case "tool-result": {
203
+ controller.enqueue({
204
+ type: "3" /* ToolCallResult */,
205
+ value: {
206
+ id: chunk.toolCallId,
207
+ result: chunk.result
208
+ }
209
+ });
210
+ break;
211
+ }
212
+ case "finish": {
213
+ const { type, ...rest } = chunk;
214
+ controller.enqueue({
215
+ type: "F" /* Finish */,
216
+ value: rest
217
+ });
218
+ break;
219
+ }
220
+ case "error": {
221
+ controller.enqueue({
222
+ type: "E" /* Error */,
223
+ value: chunk.error
224
+ });
225
+ break;
226
+ }
227
+ default: {
228
+ const unhandledType = chunkType;
229
+ throw new Error(`Unhandled chunk type: ${unhandledType}`);
230
+ }
231
+ }
232
+ }
233
+ });
234
+ }
235
+
168
236
  // src/types/ModelConfigTypes.ts
169
237
  import { z as z2 } from "zod";
170
238
  var LanguageModelV1CallSettingsSchema = z2.object({
@@ -220,39 +288,145 @@ ${config.system}`;
220
288
  }, {});
221
289
  };
222
290
 
223
- // src/runtimes/edge/streams/utils/PipeableTransformStream.ts
224
- var PipeableTransformStream = class extends TransformStream {
225
- constructor(transform) {
226
- super();
227
- const readable = transform(super.readable);
228
- Object.defineProperty(this, "readable", {
229
- value: readable,
230
- writable: false
231
- });
232
- }
233
- };
291
+ // src/runtimes/edge/EdgeRuntimeRequestOptions.ts
292
+ import { z as z3 } from "zod";
293
+ var LanguageModelV1FunctionToolSchema = z3.object({
294
+ type: z3.literal("function"),
295
+ name: z3.string(),
296
+ description: z3.string().optional(),
297
+ parameters: z3.custom(
298
+ (val) => typeof val === "object" && val !== null
299
+ )
300
+ });
301
+ var TextContentPartSchema = z3.object({
302
+ type: z3.literal("text"),
303
+ text: z3.string()
304
+ });
305
+ var ImageContentPartSchema = z3.object({
306
+ type: z3.literal("image"),
307
+ image: z3.string()
308
+ });
309
+ var CoreToolCallContentPartSchema = z3.object({
310
+ type: z3.literal("tool-call"),
311
+ toolCallId: z3.string(),
312
+ toolName: z3.string(),
313
+ args: z3.record(z3.unknown()),
314
+ result: z3.unknown().optional(),
315
+ isError: z3.boolean().optional()
316
+ });
317
+ var CoreUserMessageSchema = z3.object({
318
+ role: z3.literal("user"),
319
+ content: z3.array(
320
+ z3.discriminatedUnion("type", [
321
+ TextContentPartSchema,
322
+ ImageContentPartSchema
323
+ ])
324
+ ).min(1)
325
+ });
326
+ var CoreAssistantMessageSchema = z3.object({
327
+ role: z3.literal("assistant"),
328
+ content: z3.array(
329
+ z3.discriminatedUnion("type", [
330
+ TextContentPartSchema,
331
+ CoreToolCallContentPartSchema
332
+ ])
333
+ ).min(1)
334
+ });
335
+ var CoreSystemMessageSchema = z3.object({
336
+ role: z3.literal("system"),
337
+ content: z3.tuple([TextContentPartSchema])
338
+ });
339
+ var CoreMessageSchema = z3.discriminatedUnion("role", [
340
+ CoreSystemMessageSchema,
341
+ CoreUserMessageSchema,
342
+ CoreAssistantMessageSchema
343
+ ]);
344
+ var EdgeRuntimeRequestOptionsSchema = z3.object({
345
+ system: z3.string().optional(),
346
+ messages: z3.array(CoreMessageSchema).min(1),
347
+ tools: z3.array(LanguageModelV1FunctionToolSchema).optional()
348
+ }).merge(LanguageModelV1CallSettingsSchema).merge(LanguageModelConfigSchema);
234
349
 
235
- // src/runtimes/edge/streams/utils/streamPartEncoderStream.ts
236
- function encodeStreamPart({
237
- type,
238
- value
239
- }) {
240
- return `${type}:${JSON.stringify(value)}
241
- `;
242
- }
243
- function streamPartEncoderStream() {
244
- const encodeStream = new TransformStream({
350
+ // src/runtimes/edge/streams/toolResultStream.ts
351
+ import { z as z4 } from "zod";
352
+ import sjson from "secure-json-parse";
353
+ function toolResultStream(tools, abortSignal) {
354
+ const toolCallExecutions = /* @__PURE__ */ new Map();
355
+ return new TransformStream({
245
356
  transform(chunk, controller) {
246
- controller.enqueue(encodeStreamPart(chunk));
357
+ controller.enqueue(chunk);
358
+ const chunkType = chunk.type;
359
+ switch (chunkType) {
360
+ case "tool-call": {
361
+ const { toolCallId, toolCallType, toolName, args: argsText } = chunk;
362
+ const tool = tools?.[toolName];
363
+ if (!tool || !tool.execute) return;
364
+ const args = sjson.parse(argsText);
365
+ if (tool.parameters instanceof z4.ZodType) {
366
+ const result = tool.parameters.safeParse(args);
367
+ if (!result.success) {
368
+ controller.enqueue({
369
+ type: "tool-result",
370
+ toolCallType,
371
+ toolCallId,
372
+ toolName,
373
+ result: "Function parameter validation failed. " + JSON.stringify(result.error.issues),
374
+ isError: true
375
+ });
376
+ return;
377
+ } else {
378
+ toolCallExecutions.set(
379
+ toolCallId,
380
+ (async () => {
381
+ if (!tool.execute) return;
382
+ try {
383
+ const result2 = await tool.execute(args, { abortSignal });
384
+ controller.enqueue({
385
+ type: "tool-result",
386
+ toolCallType,
387
+ toolCallId,
388
+ toolName,
389
+ result: result2
390
+ });
391
+ } catch (error) {
392
+ controller.enqueue({
393
+ type: "tool-result",
394
+ toolCallType,
395
+ toolCallId,
396
+ toolName,
397
+ result: "Error: " + error,
398
+ isError: true
399
+ });
400
+ } finally {
401
+ toolCallExecutions.delete(toolCallId);
402
+ }
403
+ })()
404
+ );
405
+ }
406
+ }
407
+ break;
408
+ }
409
+ // ignore other parts
410
+ case "text-delta":
411
+ case "tool-call-delta":
412
+ case "tool-result":
413
+ case "finish":
414
+ case "error":
415
+ break;
416
+ default: {
417
+ const unhandledType = chunkType;
418
+ throw new Error(`Unhandled chunk type: ${unhandledType}`);
419
+ }
420
+ }
421
+ },
422
+ async flush() {
423
+ await Promise.all(toolCallExecutions.values());
247
424
  }
248
425
  });
249
- return new PipeableTransformStream((readable) => {
250
- return readable.pipeThrough(encodeStream).pipeThrough(new TextEncoderStream());
251
- });
252
426
  }
253
427
 
254
428
  // src/runtimes/edge/partial-json/parse-partial-json.ts
255
- import sjson from "secure-json-parse";
429
+ import sjson2 from "secure-json-parse";
256
430
 
257
431
  // src/runtimes/edge/partial-json/fix-json.ts
258
432
  function fixJson(input) {
@@ -575,10 +749,10 @@ function fixJson(input) {
575
749
  // src/runtimes/edge/partial-json/parse-partial-json.ts
576
750
  var parsePartialJson = (json) => {
577
751
  try {
578
- return sjson.parse(json);
752
+ return sjson2.parse(json);
579
753
  } catch {
580
754
  try {
581
- return sjson.parse(fixJson(json));
755
+ return sjson2.parse(fixJson(json));
582
756
  } catch {
583
757
  return void 0;
584
758
  }
@@ -654,8 +828,8 @@ function runResultStream() {
654
828
  });
655
829
  }
656
830
  var appendOrUpdateText = (message, textDelta) => {
657
- let contentParts = message.content;
658
- let contentPart = message.content.at(-1);
831
+ let contentParts = message.content ?? [];
832
+ let contentPart = message.content?.at(-1);
659
833
  if (contentPart?.type !== "text") {
660
834
  contentPart = { type: "text", text: textDelta };
661
835
  } else {
@@ -668,8 +842,8 @@ var appendOrUpdateText = (message, textDelta) => {
668
842
  };
669
843
  };
670
844
  var appendOrUpdateToolCall = (message, toolCallId, toolName, argsText) => {
671
- let contentParts = message.content;
672
- let contentPart = message.content.at(-1);
845
+ let contentParts = message.content ?? [];
846
+ let contentPart = message.content?.at(-1);
673
847
  if (contentPart?.type !== "tool-call" || contentPart.toolCallId !== toolCallId) {
674
848
  contentPart = {
675
849
  type: "tool-call",
@@ -693,7 +867,7 @@ var appendOrUpdateToolCall = (message, toolCallId, toolName, argsText) => {
693
867
  };
694
868
  var appendOrUpdateToolResult = (message, toolCallId, toolName, result) => {
695
869
  let found = false;
696
- const newContentParts = message.content.map((part) => {
870
+ const newContentParts = message.content?.map((part) => {
697
871
  if (part.type !== "tool-call" || part.toolCallId !== toolCallId)
698
872
  return part;
699
873
  found = true;
@@ -760,87 +934,182 @@ var appendOrUpdateCancel = (message) => {
760
934
  };
761
935
  };
762
936
 
763
- // src/runtimes/edge/streams/toolResultStream.ts
764
- import { z as z3 } from "zod";
765
- import sjson2 from "secure-json-parse";
766
- function toolResultStream(tools, abortSignal) {
767
- const toolCallExecutions = /* @__PURE__ */ new Map();
768
- return new TransformStream({
937
+ // src/runtimes/edge/streams/utils/PipeableTransformStream.ts
938
+ var PipeableTransformStream = class extends TransformStream {
939
+ constructor(transform) {
940
+ super();
941
+ const readable = transform(super.readable);
942
+ Object.defineProperty(this, "readable", {
943
+ value: readable,
944
+ writable: false
945
+ });
946
+ }
947
+ };
948
+
949
+ // src/runtimes/edge/streams/utils/streamPartEncoderStream.ts
950
+ function encodeStreamPart({
951
+ type,
952
+ value
953
+ }) {
954
+ return `${type}:${JSON.stringify(value)}
955
+ `;
956
+ }
957
+ function streamPartEncoderStream() {
958
+ const encodeStream = new TransformStream({
769
959
  transform(chunk, controller) {
770
- controller.enqueue(chunk);
771
- const chunkType = chunk.type;
772
- switch (chunkType) {
773
- case "tool-call": {
774
- const { toolCallId, toolCallType, toolName, args: argsText } = chunk;
775
- const tool = tools?.[toolName];
776
- if (!tool || !tool.execute) return;
777
- const args = sjson2.parse(argsText);
778
- if (tool.parameters instanceof z3.ZodType) {
779
- const result = tool.parameters.safeParse(args);
780
- if (!result.success) {
781
- controller.enqueue({
782
- type: "tool-result",
783
- toolCallType,
784
- toolCallId,
785
- toolName,
786
- result: "Function parameter validation failed. " + JSON.stringify(result.error.issues),
787
- isError: true
788
- });
960
+ controller.enqueue(encodeStreamPart(chunk));
961
+ }
962
+ });
963
+ return new PipeableTransformStream((readable) => {
964
+ return readable.pipeThrough(encodeStream).pipeThrough(new TextEncoderStream());
965
+ });
966
+ }
967
+
968
+ // src/runtimes/edge/createEdgeRuntimeAPI.ts
969
+ var voidStream = () => {
970
+ return new WritableStream({
971
+ abort(reason) {
972
+ console.error("Server stream processing aborted:", reason);
973
+ }
974
+ });
975
+ };
976
+ var getEdgeRuntimeStream = async ({
977
+ abortSignal,
978
+ requestData: unsafeRequest,
979
+ options: {
980
+ model: modelOrCreator,
981
+ system: serverSystem,
982
+ tools: serverTools = {},
983
+ toolChoice,
984
+ onFinish,
985
+ ...unsafeSettings
986
+ }
987
+ }) => {
988
+ const settings = LanguageModelV1CallSettingsSchema.parse(unsafeSettings);
989
+ const lmServerTools = toLanguageModelTools(serverTools);
990
+ const hasServerTools = Object.values(serverTools).some((v) => !!v.execute);
991
+ const {
992
+ system: clientSystem,
993
+ tools: clientTools = [],
994
+ messages,
995
+ apiKey,
996
+ baseUrl,
997
+ modelName,
998
+ ...callSettings
999
+ } = EdgeRuntimeRequestOptionsSchema.parse(unsafeRequest);
1000
+ const systemMessages = [];
1001
+ if (serverSystem) systemMessages.push(serverSystem);
1002
+ if (clientSystem) systemMessages.push(clientSystem);
1003
+ const system = systemMessages.join("\n\n");
1004
+ for (const clientTool of clientTools) {
1005
+ if (serverTools?.[clientTool.name]) {
1006
+ throw new Error(
1007
+ `Tool ${clientTool.name} was defined in both the client and server tools. This is not allowed.`
1008
+ );
1009
+ }
1010
+ }
1011
+ let model;
1012
+ model = typeof modelOrCreator === "function" ? await modelOrCreator({ apiKey, baseUrl, modelName }) : modelOrCreator;
1013
+ let stream;
1014
+ const streamResult = await streamMessage({
1015
+ ...settings,
1016
+ ...callSettings,
1017
+ model,
1018
+ abortSignal,
1019
+ ...!!system ? { system } : void 0,
1020
+ messages,
1021
+ tools: lmServerTools.concat(clientTools),
1022
+ ...toolChoice ? { toolChoice } : void 0
1023
+ });
1024
+ stream = streamResult.stream;
1025
+ const canExecuteTools = hasServerTools && toolChoice?.type !== "none";
1026
+ if (canExecuteTools) {
1027
+ stream = stream.pipeThrough(toolResultStream(serverTools, abortSignal));
1028
+ }
1029
+ if (canExecuteTools || onFinish) {
1030
+ const tees = stream.tee();
1031
+ stream = tees[0];
1032
+ let serverStream = tees[1];
1033
+ if (onFinish) {
1034
+ let lastChunk;
1035
+ serverStream = serverStream.pipeThrough(runResultStream()).pipeThrough(
1036
+ new TransformStream({
1037
+ transform(chunk) {
1038
+ lastChunk = chunk;
1039
+ },
1040
+ flush() {
1041
+ if (!lastChunk?.status || lastChunk.status.type === "running")
789
1042
  return;
790
- } else {
791
- toolCallExecutions.set(
792
- toolCallId,
793
- (async () => {
794
- if (!tool.execute) return;
795
- try {
796
- const result2 = await tool.execute(args, { abortSignal });
797
- controller.enqueue({
798
- type: "tool-result",
799
- toolCallType,
800
- toolCallId,
801
- toolName,
802
- result: result2
803
- });
804
- } catch (error) {
805
- controller.enqueue({
806
- type: "tool-result",
807
- toolCallType,
808
- toolCallId,
809
- toolName,
810
- result: "Error: " + error,
811
- isError: true
812
- });
813
- } finally {
814
- toolCallExecutions.delete(toolCallId);
815
- }
816
- })()
817
- );
818
- }
1043
+ const resultingMessages = [
1044
+ ...messages,
1045
+ toCoreMessage({
1046
+ role: "assistant",
1047
+ content: lastChunk.content
1048
+ })
1049
+ ];
1050
+ onFinish({
1051
+ messages: resultingMessages,
1052
+ metadata: {
1053
+ roundtrips: lastChunk.metadata?.roundtrips
1054
+ }
1055
+ });
819
1056
  }
820
- break;
821
- }
822
- // ignore other parts
823
- case "text-delta":
824
- case "tool-call-delta":
825
- case "tool-result":
826
- case "finish":
827
- case "error":
828
- break;
829
- default: {
830
- const unhandledType = chunkType;
831
- throw new Error(`Unhandled chunk type: ${unhandledType}`);
832
- }
1057
+ })
1058
+ );
1059
+ }
1060
+ serverStream.pipeTo(voidStream()).catch((e) => {
1061
+ console.error("Server stream processing error:", e);
1062
+ });
1063
+ }
1064
+ return stream;
1065
+ };
1066
+ var getEdgeRuntimeResponse = async (options) => {
1067
+ const stream = await getEdgeRuntimeStream(options);
1068
+ return new Response(
1069
+ stream.pipeThrough(assistantEncoderStream()).pipeThrough(streamPartEncoderStream()),
1070
+ {
1071
+ headers: {
1072
+ "Content-Type": "text/plain; charset=utf-8"
833
1073
  }
834
- },
835
- async flush() {
836
- await Promise.all(toolCallExecutions.values());
837
1074
  }
1075
+ );
1076
+ };
1077
+ var createEdgeRuntimeAPI = (options) => ({
1078
+ POST: async (request) => getEdgeRuntimeResponse({
1079
+ abortSignal: request.signal,
1080
+ requestData: await request.json(),
1081
+ options
1082
+ })
1083
+ });
1084
+ async function streamMessage({
1085
+ model,
1086
+ system,
1087
+ messages,
1088
+ tools,
1089
+ toolChoice,
1090
+ ...options
1091
+ }) {
1092
+ return model.doStream({
1093
+ inputFormat: "messages",
1094
+ mode: {
1095
+ type: "regular",
1096
+ ...tools ? { tools } : void 0,
1097
+ ...toolChoice ? { toolChoice } : void 0
1098
+ },
1099
+ prompt: convertToLanguageModelPrompt(system, messages),
1100
+ ...options
838
1101
  });
839
1102
  }
1103
+ function convertToLanguageModelPrompt(system, messages) {
1104
+ const languageModelMessages = [];
1105
+ if (system != null) {
1106
+ languageModelMessages.push({ role: "system", content: system });
1107
+ }
1108
+ languageModelMessages.push(...toLanguageModelMessages(messages));
1109
+ return languageModelMessages;
1110
+ }
840
1111
 
841
1112
  export {
842
- LanguageModelV1CallSettingsSchema,
843
- LanguageModelConfigSchema,
844
1113
  mergeModelConfigs,
845
1114
  toLanguageModelMessages,
846
1115
  toCoreMessages,
@@ -849,6 +1118,9 @@ export {
849
1118
  PipeableTransformStream,
850
1119
  streamPartEncoderStream,
851
1120
  runResultStream,
852
- toolResultStream
1121
+ toolResultStream,
1122
+ getEdgeRuntimeStream,
1123
+ getEdgeRuntimeResponse,
1124
+ createEdgeRuntimeAPI
853
1125
  };
854
- //# sourceMappingURL=chunk-QBSJKWMT.mjs.map
1126
+ //# sourceMappingURL=chunk-5ZTUOAPH.mjs.map