@assistant-ui/react 0.4.4 → 0.4.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1902,7 +1902,7 @@ var ThreadPrimitiveSuggestion = createActionButton(
1902
1902
  );
1903
1903
 
1904
1904
  // src/runtimes/local/useLocalRuntime.tsx
1905
- import { useInsertionEffect as useInsertionEffect3, useState as useState7 } from "react";
1905
+ import { useInsertionEffect as useInsertionEffect3, useState as useState8 } from "react";
1906
1906
 
1907
1907
  // src/utils/idUtils.tsx
1908
1908
  import { customAlphabet } from "nanoid/non-secure";
@@ -2205,163 +2205,8 @@ var TooltipIconButton = forwardRef17(({ children, tooltip, side = "bottom", ...r
2205
2205
  });
2206
2206
  TooltipIconButton.displayName = "TooltipIconButton";
2207
2207
 
2208
- // src/runtimes/local/LocalRuntime.tsx
2209
- var LocalRuntime = class extends BaseAssistantRuntime {
2210
- _proxyConfigProvider;
2211
- constructor(adapter) {
2212
- const proxyConfigProvider = new ProxyConfigProvider();
2213
- super(new LocalThreadRuntime(proxyConfigProvider, adapter));
2214
- this._proxyConfigProvider = proxyConfigProvider;
2215
- }
2216
- set adapter(adapter) {
2217
- this.thread.adapter = adapter;
2218
- }
2219
- registerModelConfigProvider(provider) {
2220
- return this._proxyConfigProvider.registerModelConfigProvider(provider);
2221
- }
2222
- switchToThread(threadId) {
2223
- if (threadId) {
2224
- throw new Error("LocalRuntime does not yet support switching threads");
2225
- }
2226
- return this.thread = new LocalThreadRuntime(
2227
- this._proxyConfigProvider,
2228
- this.thread.adapter
2229
- );
2230
- }
2231
- };
2232
- var CAPABILITIES = Object.freeze({
2233
- edit: true,
2234
- reload: true,
2235
- cancel: true,
2236
- copy: true
2237
- });
2238
- var LocalThreadRuntime = class {
2239
- constructor(configProvider, adapter) {
2240
- this.configProvider = configProvider;
2241
- this.adapter = adapter;
2242
- }
2243
- _subscriptions = /* @__PURE__ */ new Set();
2244
- abortController = null;
2245
- repository = new MessageRepository();
2246
- capabilities = CAPABILITIES;
2247
- get messages() {
2248
- return this.repository.getMessages();
2249
- }
2250
- get isRunning() {
2251
- return this.abortController != null;
2252
- }
2253
- getBranches(messageId) {
2254
- return this.repository.getBranches(messageId);
2255
- }
2256
- switchToBranch(branchId) {
2257
- this.repository.switchToBranch(branchId);
2258
- this.notifySubscribers();
2259
- }
2260
- async append(message) {
2261
- if (message.role !== "user")
2262
- throw new Error(
2263
- "Only appending user messages are supported in LocalRuntime. This is likely an internal bug in assistant-ui."
2264
- );
2265
- const userMessageId = generateId();
2266
- const userMessage = {
2267
- id: userMessageId,
2268
- role: "user",
2269
- content: message.content,
2270
- createdAt: /* @__PURE__ */ new Date()
2271
- };
2272
- this.repository.addOrUpdateMessage(message.parentId, userMessage);
2273
- await this.startRun(userMessageId);
2274
- }
2275
- async startRun(parentId) {
2276
- this.repository.resetHead(parentId);
2277
- const messages = this.repository.getMessages();
2278
- const message = {
2279
- id: generateId(),
2280
- role: "assistant",
2281
- status: { type: "in_progress" },
2282
- content: [{ type: "text", text: "" }],
2283
- createdAt: /* @__PURE__ */ new Date()
2284
- };
2285
- this.repository.addOrUpdateMessage(parentId, { ...message });
2286
- this.abortController?.abort();
2287
- this.abortController = new AbortController();
2288
- this.notifySubscribers();
2289
- try {
2290
- const updateHandler = ({ content }) => {
2291
- message.content = content;
2292
- this.repository.addOrUpdateMessage(parentId, { ...message });
2293
- this.notifySubscribers();
2294
- };
2295
- const result = await this.adapter.run({
2296
- messages,
2297
- abortSignal: this.abortController.signal,
2298
- config: this.configProvider.getModelConfig(),
2299
- onUpdate: updateHandler
2300
- });
2301
- if (result !== void 0) {
2302
- updateHandler(result);
2303
- }
2304
- if (result.status?.type === "in_progress")
2305
- throw new Error(
2306
- "Unexpected in_progress status returned from ChatModelAdapter"
2307
- );
2308
- message.status = result.status ?? { type: "done" };
2309
- this.repository.addOrUpdateMessage(parentId, { ...message });
2310
- } catch (e) {
2311
- message.status = { type: "error", error: e };
2312
- this.repository.addOrUpdateMessage(parentId, { ...message });
2313
- throw e;
2314
- } finally {
2315
- this.abortController = null;
2316
- this.notifySubscribers();
2317
- }
2318
- }
2319
- cancelRun() {
2320
- if (!this.abortController) return;
2321
- this.abortController.abort();
2322
- this.abortController = null;
2323
- }
2324
- notifySubscribers() {
2325
- for (const callback of this._subscriptions) callback();
2326
- }
2327
- subscribe(callback) {
2328
- this._subscriptions.add(callback);
2329
- return () => this._subscriptions.delete(callback);
2330
- }
2331
- addToolResult({ messageId, toolCallId, result }) {
2332
- const { parentId, message } = this.repository.getMessage(messageId);
2333
- if (message.role !== "assistant")
2334
- throw new Error("Tried to add tool result to non-assistant message");
2335
- let found = false;
2336
- const newContent = message.content.map((c) => {
2337
- if (c.type !== "tool-call") return c;
2338
- if (c.toolCallId !== toolCallId) return c;
2339
- found = true;
2340
- return {
2341
- ...c,
2342
- result
2343
- };
2344
- });
2345
- if (!found)
2346
- throw new Error("Tried to add tool result to non-existing tool call");
2347
- this.repository.addOrUpdateMessage(parentId, {
2348
- ...message,
2349
- content: newContent
2350
- });
2351
- }
2352
- };
2353
-
2354
- // src/runtimes/local/useLocalRuntime.tsx
2355
- var useLocalRuntime = (adapter) => {
2356
- const [runtime] = useState7(() => new LocalRuntime(adapter));
2357
- useInsertionEffect3(() => {
2358
- runtime.adapter = adapter;
2359
- });
2360
- return runtime;
2361
- };
2362
-
2363
2208
  // src/runtimes/edge/useEdgeRuntime.ts
2364
- import { useState as useState8 } from "react";
2209
+ import { useState as useState7 } from "react";
2365
2210
 
2366
2211
  // src/runtimes/edge/converters/toCoreMessages.ts
2367
2212
  var toCoreMessages = (message) => {
@@ -2391,6 +2236,7 @@ var toLanguageModelTools = (tools) => {
2391
2236
 
2392
2237
  // src/runtimes/edge/streams/assistantDecoderStream.ts
2393
2238
  function assistantDecoderStream() {
2239
+ const toolCallNames = /* @__PURE__ */ new Map();
2394
2240
  let currentToolCall;
2395
2241
  return new TransformStream({
2396
2242
  transform(chunk, controller) {
@@ -2415,6 +2261,7 @@ function assistantDecoderStream() {
2415
2261
  }
2416
2262
  case "1" /* ToolCallBegin */: {
2417
2263
  const { id, name } = value;
2264
+ toolCallNames.set(id, name);
2418
2265
  currentToolCall = { id, name, argsText: "" };
2419
2266
  break;
2420
2267
  }
@@ -2430,6 +2277,16 @@ function assistantDecoderStream() {
2430
2277
  });
2431
2278
  break;
2432
2279
  }
2280
+ case "3" /* ToolCallResult */: {
2281
+ controller.enqueue({
2282
+ type: "tool-result",
2283
+ toolCallType: "function",
2284
+ toolCallId: value.id,
2285
+ toolName: toolCallNames.get(value.id),
2286
+ result: value.result
2287
+ });
2288
+ break;
2289
+ }
2433
2290
  case "F" /* Finish */: {
2434
2291
  controller.enqueue({
2435
2292
  type: "finish",
@@ -2816,9 +2673,9 @@ var parsePartialJson = (json) => {
2816
2673
  };
2817
2674
 
2818
2675
  // src/runtimes/edge/streams/runResultStream.ts
2819
- function runResultStream() {
2676
+ function runResultStream(initialContent) {
2820
2677
  let message = {
2821
- content: []
2678
+ content: initialContent
2822
2679
  };
2823
2680
  const currentToolCall = { toolCallId: "", argsText: "" };
2824
2681
  return new TransformStream({
@@ -3001,6 +2858,7 @@ function toolResultStream(tools) {
3001
2858
  }
3002
2859
  case "text-delta":
3003
2860
  case "tool-call-delta":
2861
+ case "tool-result":
3004
2862
  case "finish":
3005
2863
  case "error":
3006
2864
  break;
@@ -3034,7 +2892,7 @@ var EdgeChatAdapter = class {
3034
2892
  constructor(options) {
3035
2893
  this.options = options;
3036
2894
  }
3037
- async run({ messages, abortSignal, config, onUpdate }) {
2895
+ async roundtrip(initialContent, { messages, abortSignal, config, onUpdate }) {
3038
2896
  const result = await fetch(this.options.api, {
3039
2897
  method: "POST",
3040
2898
  headers: {
@@ -3049,21 +2907,56 @@ var EdgeChatAdapter = class {
3049
2907
  }),
3050
2908
  signal: abortSignal
3051
2909
  });
3052
- const stream = result.body.pipeThrough(new TextDecoderStream()).pipeThrough(chunkByLineStream()).pipeThrough(assistantDecoderStream()).pipeThrough(toolResultStream(config.tools)).pipeThrough(runResultStream());
2910
+ const stream = result.body.pipeThrough(new TextDecoderStream()).pipeThrough(chunkByLineStream()).pipeThrough(assistantDecoderStream()).pipeThrough(toolResultStream(config.tools)).pipeThrough(runResultStream(initialContent));
2911
+ let message;
3053
2912
  let update;
3054
2913
  for await (update of asAsyncIterable(stream)) {
3055
- onUpdate(update);
2914
+ message = onUpdate(update);
3056
2915
  }
3057
2916
  if (update === void 0)
3058
2917
  throw new Error("No data received from Edge Runtime");
3059
- return update;
2918
+ return [message, update];
2919
+ }
2920
+ async run({ messages, abortSignal, config, onUpdate }) {
2921
+ let roundtripAllowance = this.options.maxToolRoundtrips ?? 1;
2922
+ let usage = {
2923
+ promptTokens: 0,
2924
+ completionTokens: 0
2925
+ };
2926
+ let result;
2927
+ let assistantMessage;
2928
+ do {
2929
+ [assistantMessage, result] = await this.roundtrip(result?.content ?? [], {
2930
+ messages: assistantMessage ? [...messages, assistantMessage] : messages,
2931
+ abortSignal,
2932
+ config,
2933
+ onUpdate
2934
+ });
2935
+ if (result.status?.type === "done") {
2936
+ usage.promptTokens += result.status.usage?.promptTokens ?? 0;
2937
+ usage.completionTokens += result.status.usage?.completionTokens ?? 0;
2938
+ }
2939
+ } while (result.status?.type === "done" && result.status.finishReason === "tool-calls" && result.content.every((c) => c.type !== "tool-call" || !!c.result) && roundtripAllowance-- > 0);
2940
+ if (result.status?.type === "done" && usage.promptTokens > 0) {
2941
+ result = {
2942
+ ...result,
2943
+ status: {
2944
+ ...result.status,
2945
+ usage
2946
+ }
2947
+ };
2948
+ }
2949
+ return result;
3060
2950
  }
3061
2951
  };
3062
2952
 
3063
2953
  // src/runtimes/edge/useEdgeRuntime.ts
3064
- var useEdgeRuntime = (options) => {
3065
- const [adapter] = useState8(() => new EdgeChatAdapter(options));
3066
- return useLocalRuntime(adapter);
2954
+ var useEdgeRuntime = ({
2955
+ initialMessages,
2956
+ ...options
2957
+ }) => {
2958
+ const [adapter] = useState7(() => new EdgeChatAdapter(options));
2959
+ return useLocalRuntime(adapter, { initialMessages });
3067
2960
  };
3068
2961
 
3069
2962
  // src/runtimes/edge/converters/toLanguageModelMessages.ts
@@ -3290,16 +3183,181 @@ var fromLanguageModelMessages = (lm, mergeRoundtrips) => {
3290
3183
  var fromCoreMessages = (message) => {
3291
3184
  return message.map((message2) => {
3292
3185
  return {
3293
- ...message2,
3294
3186
  id: generateId(),
3295
3187
  createdAt: /* @__PURE__ */ new Date(),
3296
3188
  ...message2.role === "assistant" ? {
3297
3189
  status: { type: "done" }
3298
- } : void 0
3190
+ } : void 0,
3191
+ ...message2
3299
3192
  };
3300
3193
  });
3301
3194
  };
3302
3195
 
3196
+ // src/runtimes/local/LocalRuntime.tsx
3197
+ var LocalRuntime = class extends BaseAssistantRuntime {
3198
+ _proxyConfigProvider;
3199
+ constructor(adapter, options) {
3200
+ const proxyConfigProvider = new ProxyConfigProvider();
3201
+ super(new LocalThreadRuntime(proxyConfigProvider, adapter, options));
3202
+ this._proxyConfigProvider = proxyConfigProvider;
3203
+ }
3204
+ set adapter(adapter) {
3205
+ this.thread.adapter = adapter;
3206
+ }
3207
+ registerModelConfigProvider(provider) {
3208
+ return this._proxyConfigProvider.registerModelConfigProvider(provider);
3209
+ }
3210
+ switchToThread(threadId) {
3211
+ if (threadId) {
3212
+ throw new Error("LocalRuntime does not yet support switching threads");
3213
+ }
3214
+ return this.thread = new LocalThreadRuntime(
3215
+ this._proxyConfigProvider,
3216
+ this.thread.adapter
3217
+ );
3218
+ }
3219
+ };
3220
+ var CAPABILITIES = Object.freeze({
3221
+ edit: true,
3222
+ reload: true,
3223
+ cancel: true,
3224
+ copy: true
3225
+ });
3226
+ var LocalThreadRuntime = class {
3227
+ constructor(configProvider, adapter, options) {
3228
+ this.configProvider = configProvider;
3229
+ this.adapter = adapter;
3230
+ if (options?.initialMessages) {
3231
+ let parentId = null;
3232
+ const messages = fromCoreMessages(options.initialMessages);
3233
+ for (const message of messages) {
3234
+ this.repository.addOrUpdateMessage(parentId, message);
3235
+ parentId = message.id;
3236
+ }
3237
+ }
3238
+ }
3239
+ _subscriptions = /* @__PURE__ */ new Set();
3240
+ abortController = null;
3241
+ repository = new MessageRepository();
3242
+ capabilities = CAPABILITIES;
3243
+ get messages() {
3244
+ return this.repository.getMessages();
3245
+ }
3246
+ get isRunning() {
3247
+ return this.abortController != null;
3248
+ }
3249
+ getBranches(messageId) {
3250
+ return this.repository.getBranches(messageId);
3251
+ }
3252
+ switchToBranch(branchId) {
3253
+ this.repository.switchToBranch(branchId);
3254
+ this.notifySubscribers();
3255
+ }
3256
+ async append(message) {
3257
+ if (message.role !== "user")
3258
+ throw new Error(
3259
+ "Only appending user messages are supported in LocalRuntime. This is likely an internal bug in assistant-ui."
3260
+ );
3261
+ const userMessageId = generateId();
3262
+ const userMessage = {
3263
+ id: userMessageId,
3264
+ role: "user",
3265
+ content: message.content,
3266
+ createdAt: /* @__PURE__ */ new Date()
3267
+ };
3268
+ this.repository.addOrUpdateMessage(message.parentId, userMessage);
3269
+ await this.startRun(userMessageId);
3270
+ }
3271
+ async startRun(parentId) {
3272
+ this.repository.resetHead(parentId);
3273
+ const messages = this.repository.getMessages();
3274
+ const message = {
3275
+ id: generateId(),
3276
+ role: "assistant",
3277
+ status: { type: "in_progress" },
3278
+ content: [{ type: "text", text: "" }],
3279
+ createdAt: /* @__PURE__ */ new Date()
3280
+ };
3281
+ this.repository.addOrUpdateMessage(parentId, { ...message });
3282
+ this.abortController?.abort();
3283
+ this.abortController = new AbortController();
3284
+ this.notifySubscribers();
3285
+ try {
3286
+ const updateHandler = ({ content }) => {
3287
+ message.content = content;
3288
+ const newMessage = { ...message };
3289
+ this.repository.addOrUpdateMessage(parentId, newMessage);
3290
+ this.notifySubscribers();
3291
+ return newMessage;
3292
+ };
3293
+ const result = await this.adapter.run({
3294
+ messages,
3295
+ abortSignal: this.abortController.signal,
3296
+ config: this.configProvider.getModelConfig(),
3297
+ onUpdate: updateHandler
3298
+ });
3299
+ if (result !== void 0) {
3300
+ updateHandler(result);
3301
+ }
3302
+ if (result.status?.type === "in_progress")
3303
+ throw new Error(
3304
+ "Unexpected in_progress status returned from ChatModelAdapter"
3305
+ );
3306
+ message.status = result.status ?? { type: "done" };
3307
+ this.repository.addOrUpdateMessage(parentId, { ...message });
3308
+ } catch (e) {
3309
+ message.status = { type: "error", error: e };
3310
+ this.repository.addOrUpdateMessage(parentId, { ...message });
3311
+ throw e;
3312
+ } finally {
3313
+ this.abortController = null;
3314
+ this.notifySubscribers();
3315
+ }
3316
+ }
3317
+ cancelRun() {
3318
+ if (!this.abortController) return;
3319
+ this.abortController.abort();
3320
+ this.abortController = null;
3321
+ }
3322
+ notifySubscribers() {
3323
+ for (const callback of this._subscriptions) callback();
3324
+ }
3325
+ subscribe(callback) {
3326
+ this._subscriptions.add(callback);
3327
+ return () => this._subscriptions.delete(callback);
3328
+ }
3329
+ addToolResult({ messageId, toolCallId, result }) {
3330
+ const { parentId, message } = this.repository.getMessage(messageId);
3331
+ if (message.role !== "assistant")
3332
+ throw new Error("Tried to add tool result to non-assistant message");
3333
+ let found = false;
3334
+ const newContent = message.content.map((c) => {
3335
+ if (c.type !== "tool-call") return c;
3336
+ if (c.toolCallId !== toolCallId) return c;
3337
+ found = true;
3338
+ return {
3339
+ ...c,
3340
+ result
3341
+ };
3342
+ });
3343
+ if (!found)
3344
+ throw new Error("Tried to add tool result to non-existing tool call");
3345
+ this.repository.addOrUpdateMessage(parentId, {
3346
+ ...message,
3347
+ content: newContent
3348
+ });
3349
+ }
3350
+ };
3351
+
3352
+ // src/runtimes/local/useLocalRuntime.tsx
3353
+ var useLocalRuntime = (adapter, options) => {
3354
+ const [runtime] = useState8(() => new LocalRuntime(adapter, options));
3355
+ useInsertionEffect3(() => {
3356
+ runtime.adapter = adapter;
3357
+ });
3358
+ return runtime;
3359
+ };
3360
+
3303
3361
  // src/ui/thread-config.tsx
3304
3362
  import { createContext as createContext5, useContext as useContext5 } from "react";
3305
3363
  import { Fragment as Fragment3, jsx as jsx29 } from "react/jsx-runtime";
@@ -3895,10 +3953,7 @@ var ThreadScrollToBottom = forwardRef26((props, ref) => {
3895
3953
  thread: { scrollToBottom: { tooltip = "Scroll to bottom" } = {} } = {}
3896
3954
  } = {}
3897
3955
  } = useThreadConfig();
3898
- return /* @__PURE__ */ jsx41(thread_exports.ScrollToBottom, { asChild: true, children: /* @__PURE__ */ jsxs13(ThreadScrollToBottomIconButton, { tooltip, ...props, ref, children: [
3899
- "|",
3900
- props.children ?? /* @__PURE__ */ jsx41(ArrowDownIcon, {})
3901
- ] }) });
3956
+ return /* @__PURE__ */ jsx41(thread_exports.ScrollToBottom, { asChild: true, children: /* @__PURE__ */ jsx41(ThreadScrollToBottomIconButton, { tooltip, ...props, ref, children: props.children ?? /* @__PURE__ */ jsx41(ArrowDownIcon, {}) }) });
3902
3957
  });
3903
3958
  ThreadScrollToBottom.displayName = "ThreadScrollToBottom";
3904
3959
  var exports10 = {