@assistant-ui/react 0.4.4 → 0.4.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/edge.mjs CHANGED
@@ -34,6 +34,15 @@ function assistantEncoderStream() {
34
34
  }
35
35
  case "tool-call":
36
36
  break;
37
+ case "tool-result": {
38
+ controller.enqueue(
39
+ formatStreamPart("3" /* ToolCallResult */, {
40
+ id: chunk.toolCallId,
41
+ result: chunk.result
42
+ })
43
+ );
44
+ break;
45
+ }
37
46
  case "finish": {
38
47
  const { type, ...rest } = chunk;
39
48
  controller.enqueue(
@@ -177,21 +186,658 @@ function toLanguageModelMessages(message) {
177
186
  }
178
187
 
179
188
  // src/runtimes/edge/createEdgeRuntimeAPI.ts
180
- var createEdgeRuntimeAPI = ({ model }) => {
189
+ import { z as z3 } from "zod";
190
+
191
+ // src/runtimes/edge/converters/toLanguageModelTools.ts
192
+ import { z } from "zod";
193
+ import zodToJsonSchema from "zod-to-json-schema";
194
/**
 * Convert a record of tool definitions (keyed by tool name) into the array of
 * function-tool descriptors passed to the language model's `tools` option.
 *
 * A Zod `parameters` schema is converted with `zodToJsonSchema`; any other
 * `parameters` value is forwarded unchanged. A falsy `description` is omitted.
 *
 * @param tools record of tool definitions; may be undefined/null.
 * @returns `{ type: "function", name, description?, parameters }[]`
 *   (empty when `tools` is falsy).
 */
var toLanguageModelTools = (tools) => {
  const descriptors = [];
  if (!tools) return descriptors;
  for (const [name, tool] of Object.entries(tools)) {
    const parameters = tool.parameters instanceof z.ZodType
      ? zodToJsonSchema(tool.parameters)
      : tool.parameters;
    descriptors.push({
      type: "function",
      name,
      ...(tool.description ? { description: tool.description } : {}),
      parameters
    });
  }
  return descriptors;
};
203
+
204
+ // src/runtimes/edge/streams/toolResultStream.ts
205
+ import { z as z2 } from "zod";
206
+ import sjson from "secure-json-parse";
207
/**
 * Create a TransformStream that forwards every language-model stream part
 * unchanged and executes server tools as their calls arrive.
 *
 * For each "tool-call" part whose `toolName` refers to a tool with an
 * `execute` function:
 * - the arguments text is parsed with secure-json-parse;
 * - when the tool's `parameters` is a Zod schema, the arguments are validated
 *   and the tool receives the schema's parsed output (`data`); tools declared
 *   with a plain JSON schema receive the parsed arguments as-is;
 * - unparseable or invalid arguments produce an `{ type: "error" }` part with
 *   the message "Invalid tool call arguments";
 * - the tool's resolved value is emitted as a "tool-result" part, and a
 *   thrown/rejected error as an "error" part.
 *
 * `flush` waits for every pending execution, so the readable side only
 * closes after all tool results have been enqueued.
 *
 * @param tools record of server tools keyed by name; may be undefined.
 * @returns a TransformStream over stream parts.
 */
function toolResultStream(tools) {
  // Pending tool executions keyed by toolCallId; awaited in flush().
  const toolCallExecutions = /* @__PURE__ */ new Map();

  // enqueue() throws once the readable side is errored or cancelled (e.g.
  // the upstream model stream was aborted while a tool was still running).
  // A result produced after that point has nowhere to go; dropping it avoids
  // an unhandled rejection from the detached execution promise.
  const safeEnqueue = (controller, part) => {
    try {
      controller.enqueue(part);
    } catch {
      // stream already closed, errored or cancelled
    }
  };

  return new TransformStream({
    transform(chunk, controller) {
      controller.enqueue(chunk);
      const chunkType = chunk.type;
      switch (chunkType) {
        case "tool-call": {
          const { toolCallId, toolCallType, toolName, args: argsText } = chunk;
          const tool = tools?.[toolName];
          if (!tool || !tool.execute) return;

          let args;
          try {
            args = sjson.parse(argsText);
          } catch {
            // Report malformed JSON the same way as schema-invalid args
            // instead of erroring the whole stream.
            controller.enqueue({
              type: "error",
              error: new Error("Invalid tool call arguments")
            });
            return;
          }

          if (tool.parameters instanceof z2.ZodType) {
            const validation = tool.parameters.safeParse(args);
            if (!validation.success) {
              controller.enqueue({
                type: "error",
                error: new Error("Invalid tool call arguments")
              });
              return;
            }
            // Pass the schema's output so defaults/transforms are applied.
            args = validation.data;
          }

          toolCallExecutions.set(
            toolCallId,
            (async () => {
              try {
                const result = await tool.execute(args);
                safeEnqueue(controller, {
                  type: "tool-result",
                  toolCallType,
                  toolCallId,
                  toolName,
                  result
                });
              } catch (error) {
                safeEnqueue(controller, { type: "error", error });
              } finally {
                toolCallExecutions.delete(toolCallId);
              }
            })()
          );
          break;
        }
        case "text-delta":
        case "tool-call-delta":
        case "tool-result":
        case "finish":
        case "error":
          break;
        default: {
          const unhandledType = chunkType;
          throw new Error(`Unhandled chunk type: ${unhandledType}`);
        }
      }
    },
    async flush() {
      await Promise.all(toolCallExecutions.values());
    }
  });
}
271
+
272
+ // src/runtimes/edge/partial-json/parse-partial-json.ts
273
+ import sjson2 from "secure-json-parse";
274
+
275
+ // src/runtimes/edge/partial-json/fix-json.ts
276
/**
 * Repair truncated JSON text so that it can be parsed.
 *
 * Walks `input` with a state machine kept as a stack of parser states and
 * tracks `lastValidIndex`, the last character after which the text can
 * still be completed into valid JSON. The input is cut just after that index,
 * and the constructs still open on the stack are closed innermost-first:
 * strings get a closing quote, objects `}`, arrays `]`, and partial
 * `true`/`false`/`null` literals are completed.
 *
 * Intended for prefixes of valid JSON (e.g. streamed tool-call arguments);
 * behaviour on otherwise malformed input is best-effort.
 *
 * @param input possibly-truncated JSON text.
 * @returns the repaired text.
 */
function fixJson(input) {
  const stack = ["ROOT"];
  let lastValidIndex = -1;
  // Index at which the current true/false/null literal started.
  let literalStart = null;

  // Handle the first character of a value. The state on top of the stack is
  // replaced with `swapState`, i.e. what to expect once this value finishes.
  function processValueStart(char, i, swapState) {
    switch (char) {
      case '"': {
        lastValidIndex = i;
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_STRING");
        break;
      }
      case "f":
      case "t":
      case "n": {
        lastValidIndex = i;
        literalStart = i;
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_LITERAL");
        break;
      }
      case "-": {
        // A lone "-" is not yet a number, so lastValidIndex is not advanced.
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_NUMBER");
        break;
      }
      case "0":
      case "1":
      case "2":
      case "3":
      case "4":
      case "5":
      case "6":
      case "7":
      case "8":
      case "9": {
        lastValidIndex = i;
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_NUMBER");
        break;
      }
      case "{": {
        lastValidIndex = i;
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_OBJECT_START");
        break;
      }
      case "[": {
        lastValidIndex = i;
        stack.pop();
        stack.push(swapState);
        stack.push("INSIDE_ARRAY_START");
        break;
      }
    }
  }

  // Handle a character following a complete value inside an object.
  function processAfterObjectValue(char, i) {
    switch (char) {
      case ",": {
        stack.pop();
        stack.push("INSIDE_OBJECT_AFTER_COMMA");
        break;
      }
      case "}": {
        lastValidIndex = i;
        stack.pop();
        break;
      }
    }
  }

  // Handle a character following a complete value inside an array.
  function processAfterArrayValue(char, i) {
    switch (char) {
      case ",": {
        stack.pop();
        stack.push("INSIDE_ARRAY_AFTER_COMMA");
        break;
      }
      case "]": {
        lastValidIndex = i;
        stack.pop();
        break;
      }
    }
  }

  for (let i = 0; i < input.length; i++) {
    const char = input[i];
    const currentState = stack[stack.length - 1];
    switch (currentState) {
      case "ROOT":
        processValueStart(char, i, "FINISH");
        break;
      case "INSIDE_OBJECT_START": {
        switch (char) {
          case '"': {
            stack.pop();
            stack.push("INSIDE_OBJECT_KEY");
            break;
          }
          case "}": {
            lastValidIndex = i;
            stack.pop();
            break;
          }
        }
        break;
      }
      case "INSIDE_OBJECT_AFTER_COMMA": {
        switch (char) {
          case '"': {
            stack.pop();
            stack.push("INSIDE_OBJECT_KEY");
            break;
          }
        }
        break;
      }
      case "INSIDE_OBJECT_KEY": {
        switch (char) {
          case '"': {
            stack.pop();
            stack.push("INSIDE_OBJECT_AFTER_KEY");
            break;
          }
        }
        break;
      }
      case "INSIDE_OBJECT_AFTER_KEY": {
        switch (char) {
          case ":": {
            stack.pop();
            stack.push("INSIDE_OBJECT_BEFORE_VALUE");
            break;
          }
        }
        break;
      }
      case "INSIDE_OBJECT_BEFORE_VALUE": {
        processValueStart(char, i, "INSIDE_OBJECT_AFTER_VALUE");
        break;
      }
      case "INSIDE_OBJECT_AFTER_VALUE": {
        processAfterObjectValue(char, i);
        break;
      }
      case "INSIDE_STRING": {
        switch (char) {
          case '"': {
            stack.pop();
            lastValidIndex = i;
            break;
          }
          case "\\": {
            // Do not keep a dangling backslash; the escaped char decides.
            stack.push("INSIDE_STRING_ESCAPE");
            break;
          }
          default: {
            lastValidIndex = i;
          }
        }
        break;
      }
      case "INSIDE_ARRAY_START": {
        switch (char) {
          case "]": {
            lastValidIndex = i;
            stack.pop();
            break;
          }
          default: {
            // lastValidIndex is deliberately NOT advanced here:
            // processValueStart advances it for characters that begin a
            // completable value, while a lone "-" (or whitespace / stray
            // punctuation) must not be kept — "[-" would become "[-]".
            processValueStart(char, i, "INSIDE_ARRAY_AFTER_VALUE");
            break;
          }
        }
        break;
      }
      case "INSIDE_ARRAY_AFTER_VALUE": {
        switch (char) {
          case ",": {
            stack.pop();
            stack.push("INSIDE_ARRAY_AFTER_COMMA");
            break;
          }
          case "]": {
            lastValidIndex = i;
            stack.pop();
            break;
          }
          default: {
            lastValidIndex = i;
            break;
          }
        }
        break;
      }
      case "INSIDE_ARRAY_AFTER_COMMA": {
        processValueStart(char, i, "INSIDE_ARRAY_AFTER_VALUE");
        break;
      }
      case "INSIDE_STRING_ESCAPE": {
        stack.pop();
        lastValidIndex = i;
        break;
      }
      case "INSIDE_NUMBER": {
        switch (char) {
          case "0":
          case "1":
          case "2":
          case "3":
          case "4":
          case "5":
          case "6":
          case "7":
          case "8":
          case "9": {
            lastValidIndex = i;
            break;
          }
          case "e":
          case "E":
          case "-":
          case ".": {
            // Exponent/sign/decimal point alone cannot end a number.
            break;
          }
          case ",": {
            stack.pop();
            if (stack[stack.length - 1] === "INSIDE_ARRAY_AFTER_VALUE") {
              processAfterArrayValue(char, i);
            }
            if (stack[stack.length - 1] === "INSIDE_OBJECT_AFTER_VALUE") {
              processAfterObjectValue(char, i);
            }
            break;
          }
          case "}": {
            stack.pop();
            if (stack[stack.length - 1] === "INSIDE_OBJECT_AFTER_VALUE") {
              processAfterObjectValue(char, i);
            }
            break;
          }
          case "]": {
            stack.pop();
            if (stack[stack.length - 1] === "INSIDE_ARRAY_AFTER_VALUE") {
              processAfterArrayValue(char, i);
            }
            break;
          }
          default: {
            stack.pop();
            break;
          }
        }
        break;
      }
      case "INSIDE_LITERAL": {
        const partialLiteral = input.substring(literalStart, i + 1);
        if (
          !"false".startsWith(partialLiteral) &&
          !"true".startsWith(partialLiteral) &&
          !"null".startsWith(partialLiteral)
        ) {
          // The literal ended at the previous character; let the enclosing
          // container handle this one.
          stack.pop();
          if (stack[stack.length - 1] === "INSIDE_OBJECT_AFTER_VALUE") {
            processAfterObjectValue(char, i);
          } else if (stack[stack.length - 1] === "INSIDE_ARRAY_AFTER_VALUE") {
            processAfterArrayValue(char, i);
          }
        } else {
          lastValidIndex = i;
        }
        break;
      }
    }
  }

  // Keep the completable prefix, then close every open construct,
  // innermost first.
  let result = input.slice(0, lastValidIndex + 1);
  for (let i = stack.length - 1; i >= 0; i--) {
    const state = stack[i];
    switch (state) {
      case "INSIDE_STRING": {
        result += '"';
        break;
      }
      case "INSIDE_OBJECT_KEY":
      case "INSIDE_OBJECT_AFTER_KEY":
      case "INSIDE_OBJECT_AFTER_COMMA":
      case "INSIDE_OBJECT_START":
      case "INSIDE_OBJECT_BEFORE_VALUE":
      case "INSIDE_OBJECT_AFTER_VALUE": {
        result += "}";
        break;
      }
      case "INSIDE_ARRAY_START":
      case "INSIDE_ARRAY_AFTER_COMMA":
      case "INSIDE_ARRAY_AFTER_VALUE": {
        result += "]";
        break;
      }
      case "INSIDE_LITERAL": {
        const partialLiteral = input.substring(literalStart, input.length);
        if ("true".startsWith(partialLiteral)) {
          result += "true".slice(partialLiteral.length);
        } else if ("false".startsWith(partialLiteral)) {
          result += "false".slice(partialLiteral.length);
        } else if ("null".startsWith(partialLiteral)) {
          result += "null".slice(partialLiteral.length);
        }
        break;
      }
    }
  }
  return result;
}
592
+
593
+ // src/runtimes/edge/partial-json/parse-partial-json.ts
594
/**
 * Best-effort parse of possibly-truncated JSON text.
 *
 * The text is first parsed as-is; if that fails it is repaired with
 * `fixJson` and parsed again.
 *
 * @param json JSON text, possibly cut off mid-value.
 * @returns the parsed value, or `undefined` when neither attempt succeeds
 *   (including when `fixJson` itself throws).
 */
var parsePartialJson = (json) => {
  const candidates = [() => json, () => fixJson(json)];
  for (const candidate of candidates) {
    try {
      return sjson2.parse(candidate());
    } catch {
      // try the next, more lenient candidate
    }
  }
  return void 0;
};
605
+
606
+ // src/runtimes/edge/streams/runResultStream.ts
607
/**
 * Accumulate language-model stream parts into assistant-message snapshots.
 *
 * Keeps an immutable `message` (`{ content, status? }`). Each text delta,
 * tool-call delta, tool call, tool result and finish part produces a new
 * `message` object, which is enqueued downstream. An "error" part is
 * rethrown, which errors the stream.
 *
 * @param initialContent content parts the message starts with.
 * @returns a TransformStream from stream parts to message snapshots.
 */
function runResultStream(initialContent) {
  let message = {
    content: initialContent
  };
  // Arguments text accumulated so far for the tool call being streamed.
  const currentToolCall = { toolCallId: "", argsText: "" };
  return new TransformStream({
    transform(chunk, controller) {
      const chunkType = chunk.type;
      switch (chunkType) {
        case "text-delta": {
          message = appendOrUpdateText(message, chunk.textDelta);
          controller.enqueue(message);
          break;
        }
        case "tool-call-delta": {
          const { toolCallId, toolName, argsTextDelta } = chunk;
          if (currentToolCall.toolCallId !== toolCallId) {
            currentToolCall.toolCallId = toolCallId;
            currentToolCall.argsText = argsTextDelta;
          } else {
            currentToolCall.argsText += argsTextDelta;
          }
          message = appendOrUpdateToolCall(
            message,
            toolCallId,
            toolName,
            currentToolCall.argsText
          );
          controller.enqueue(message);
          break;
        }
        case "tool-call": {
          // A provider may emit a complete "tool-call" part without any
          // preceding "tool-call-delta" parts. Record such calls so that the
          // matching "tool-result" can find them; calls already built from
          // deltas are left as they are.
          const { toolCallId, toolName, args } = chunk;
          const alreadyRecorded = message.content.some(
            (part) => part.type === "tool-call" && part.toolCallId === toolCallId
          );
          if (!alreadyRecorded) {
            message = appendOrUpdateToolCall(message, toolCallId, toolName, args);
            controller.enqueue(message);
          }
          break;
        }
        case "tool-result": {
          message = appendOrUpdateToolResult(
            message,
            chunk.toolCallId,
            chunk.toolName,
            chunk.result
          );
          controller.enqueue(message);
          break;
        }
        case "finish": {
          message = appendOrUpdateFinish(message, chunk);
          controller.enqueue(message);
          break;
        }
        case "error": {
          throw chunk.error;
        }
        default: {
          const unhandledType = chunkType;
          throw new Error(`Unhandled chunk type: ${unhandledType}`);
        }
      }
    }
  });
}
667
/**
 * Return a copy of `message` with `textDelta` appended to its text.
 *
 * If the last content part is a text part, it is replaced by a new text part
 * holding the concatenated text; otherwise a new text part is appended.
 * Neither the input message nor its content array is mutated.
 */
var appendOrUpdateText = (message, textDelta) => {
  const lastPart = message.content.at(-1);
  const extendsLast = lastPart?.type === "text";
  const preceding = extendsLast ? message.content.slice(0, -1) : message.content;
  const text = extendsLast ? lastPart.text + textDelta : textDelta;
  return {
    ...message,
    content: preceding.concat([{ type: "text", text }])
  };
};
681
/**
 * Return a copy of `message` whose trailing tool-call part reflects the
 * latest `argsText` for `toolCallId`.
 *
 * If the last content part is a tool call with the same `toolCallId`, it is
 * replaced by a copy with updated `argsText`/`args`; otherwise a new
 * tool-call part is appended. `args` is `parsePartialJson(argsText)`, so it
 * may be `undefined` while the text is still incomplete.
 */
var appendOrUpdateToolCall = (message, toolCallId, toolName, argsText) => {
  const lastPart = message.content.at(-1);
  const updatesLast =
    lastPart?.type === "tool-call" && lastPart.toolCallId === toolCallId;
  const args = parsePartialJson(argsText);
  const nextPart = updatesLast
    ? { ...lastPart, argsText, args }
    : { type: "tool-call", toolCallId, toolName, argsText, args };
  const preceding = updatesLast ? message.content.slice(0, -1) : message.content;
  return {
    ...message,
    content: preceding.concat([nextPart])
  };
};
705
/**
 * Return a copy of `message` in which every tool-call part with `toolCallId`
 * carries `result`.
 *
 * @throws Error if a matching part has a different `toolName`, or if no
 *   tool-call part with `toolCallId` exists.
 */
var appendOrUpdateToolResult = (message, toolCallId, toolName, result) => {
  let matched = false;
  const content = [];
  for (const part of message.content) {
    const isTarget = part.type === "tool-call" && part.toolCallId === toolCallId;
    if (!isTarget) {
      content.push(part);
      continue;
    }
    matched = true;
    if (part.toolName !== toolName) {
      throw new Error(
        `Tool call ${toolCallId} found with tool name ${part.toolName}, but expected ${toolName}`
      );
    }
    content.push({ ...part, result });
  }
  if (!matched) {
    throw new Error(
      `Received tool result for unknown tool call "${toolName}" / "${toolCallId}". This is likely an internal bug in assistant-ui.`
    );
  }
  return {
    ...message,
    content
  };
};
729
/**
 * Return a copy of `message` marked as finished.
 *
 * `status` becomes `{ type: "done", ...details }`, where `details` are all
 * properties of the finish `chunk` except its own `type`.
 */
var appendOrUpdateFinish = (message, chunk) => {
  const finishDetails = { ...chunk };
  delete finishDetails.type;
  return {
    ...message,
    status: { type: "done", ...finishDetails }
  };
};
739
+
740
+ // src/runtimes/edge/createEdgeRuntimeAPI.ts
741
// Schema for the generation settings accepted by createEdgeRuntimeAPI.
// Every option not destructured explicitly there (model, system, tools,
// toolChoice, onFinish) is parsed against this schema once, when the handler
// is created, and the parsed result is spread into each streamMessage call.
// NOTE(review): z.object strips unknown keys by default, so unrecognized
// options would be dropped silently — confirm that is intended.
var LanguageModelSettingsSchema = z3.object({
  // Optional positive integer.
  maxTokens: z3.number().int().positive().optional(),
  temperature: z3.number().optional(),
  topP: z3.number().optional(),
  presencePenalty: z3.number().optional(),
  frequencyPenalty: z3.number().optional(),
  // Optional integer.
  seed: z3.number().int().optional(),
  // Optional record of string keys to (possibly undefined) string values.
  // Presumably forwarded as extra request headers to the model provider —
  // confirm in streamMessage / model.doStream.
  headers: z3.record(z3.string().optional()).optional()
});
750
/**
 * Create a WritableStream that discards every chunk written to it.
 *
 * Used as the destination for the server-side branch of the teed response
 * stream. If the stream is aborted, the abort reason is logged with
 * `console.error`.
 */
var voidStream = () => {
  const discardingSink = {
    abort(reason) {
      console.error("Server stream processing aborted:", reason);
    }
  };
  return new WritableStream(discardingSink);
};
757
/**
 * Create an edge-runtime route handler that streams assistant responses.
 *
 * @param options.model language model whose `doStream` is called (through
 *   streamMessage) for each request.
 * @param options.system optional server-side system prompt; it is placed
 *   before the client's system prompt, joined with a blank line.
 * @param options.tools record of server tools (default `{}`). Tools with an
 *   `execute` function are run on the server via toolResultStream, unless
 *   `toolChoice.type` is "none".
 * @param options.toolChoice optional tool choice forwarded to the model.
 * @param options.onFinish optional callback invoked once the response
 *   finishes, with the finish reason, usage, log probabilities, the request
 *   messages plus the generated assistant message, and the model call's
 *   warnings / raw call / raw response.
 * @param options...settings remaining generation settings, validated with
 *   LanguageModelSettingsSchema (throws at creation time if invalid).
 * @returns `{ POST }`, a handler that reads `{ system?, tools?, messages }`
 *   from the JSON request body and returns a streaming text/plain Response.
 *   It throws if a client tool has the same name as a server tool.
 */
var createEdgeRuntimeAPI = ({
  model,
  system: serverSystem,
  tools: serverTools = {},
  toolChoice,
  onFinish,
  ...unsafeSettings
}) => {
  const settings = LanguageModelSettingsSchema.parse(unsafeSettings);
  const lmServerTools = toLanguageModelTools(serverTools);
  const hasServerTools = Object.values(serverTools).some((v) => !!v.execute);

  const POST = async (request) => {
    const {
      system: clientSystem,
      // Clients may omit `tools`; default to none so the duplicate check
      // below does not throw on `undefined` and `concat` does not add an
      // `undefined` entry to the tool list.
      tools: clientTools = [],
      messages
    } = await request.json();

    const systemMessages = [];
    if (serverSystem) systemMessages.push(serverSystem);
    if (clientSystem) systemMessages.push(clientSystem);
    const system = systemMessages.join("\n\n");

    for (const clientTool of clientTools) {
      if (serverTools?.[clientTool.name]) {
        throw new Error(
          `Tool ${clientTool.name} was defined in both the client and server tools. This is not allowed.`
        );
      }
    }

    const streamResult = await streamMessage({
      ...settings,
      model,
      abortSignal: request.signal,
      ...(system ? { system } : void 0),
      messages,
      tools: lmServerTools.concat(clientTools),
      ...(toolChoice ? { toolChoice } : void 0)
    });

    let stream = streamResult.stream;
    const canExecuteTools = hasServerTools && toolChoice?.type !== "none";
    if (canExecuteTools) {
      stream = stream.pipeThrough(toolResultStream(serverTools));
    }

    if (canExecuteTools || onFinish) {
      // One branch goes to the client, the other is drained on the server.
      // NOTE(review): presumably this keeps server tools running / onFinish
      // firing even if the client stops reading its branch — confirm.
      const [clientBranch, serverBranch] = stream.tee();
      stream = clientBranch;
      let serverStream = serverBranch;
      if (onFinish) {
        serverStream = serverStream.pipeThrough(runResultStream([])).pipeThrough(
          new TransformStream({
            transform(chunk) {
              if (chunk.status?.type !== "done") return;
              const resultingMessages = [
                ...messages,
                {
                  role: "assistant",
                  content: chunk.content
                }
              ];
              onFinish({
                finishReason: chunk.status.finishReason,
                usage: chunk.status.usage,
                messages: resultingMessages,
                // The finish part's field is `logprobs`; the previous
                // `logprops` spelling always yielded undefined.
                logProbs: chunk.status.logprobs,
                warnings: streamResult.warnings,
                rawCall: streamResult.rawCall,
                rawResponse: streamResult.rawResponse
              });
            }
          })
        );
      }
      serverStream.pipeTo(voidStream()).catch((e) => {
        console.error("Server stream processing error:", e);
      });
    }

    return new Response(
      stream.pipeThrough(assistantEncoderStream()).pipeThrough(new TextEncoderStream()),
      {
        headers: {
          // Header names must be spelled as HTTP headers; a `contentType`
          // key would create a bogus "contenttype" header instead.
          "Content-Type": "text/plain; charset=utf-8"
        }
      }
    );
  };

  return { POST };
};
@@ -203,7 +849,7 @@ async function streamMessage({
203
849
  toolChoice,
204
850
  ...options
205
851
  }) {
206
- const { stream, warnings, rawResponse } = await model.doStream({
852
+ return model.doStream({
207
853
  inputFormat: "messages",
208
854
  mode: {
209
855
  type: "regular",
@@ -213,11 +859,6 @@ async function streamMessage({
213
859
  prompt: convertToLanguageModelPrompt(system, messages),
214
860
  ...options
215
861
  });
216
- return {
217
- stream: stream.pipeThrough(assistantEncoderStream()).pipeThrough(new TextEncoderStream()),
218
- warnings,
219
- rawResponse
220
- };
221
862
  }
222
863
  function convertToLanguageModelPrompt(system, messages) {
223
864
  const languageModelMessages = [];