@langchain/langgraph-sdk 0.0.40 → 0.0.41

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -58,6 +58,92 @@ function findLastIndex(array, predicate) {
58
58
  }
59
59
  return -1;
60
60
  }
61
+ function getBranchSequence(history) {
62
+ const childrenMap = {};
63
+ // First pass - collect nodes for each checkpoint
64
+ history.forEach((state) => {
65
+ const checkpointId = state.parent_checkpoint?.checkpoint_id ?? "$";
66
+ childrenMap[checkpointId] ??= [];
67
+ childrenMap[checkpointId].push(state);
68
+ });
69
+ const rootSequence = { type: "sequence", items: [] };
70
+ const queue = [{ id: "$", sequence: rootSequence, path: [] }];
71
+ const paths = [];
72
+ const visited = new Set();
73
+ while (queue.length > 0) {
74
+ const task = queue.shift();
75
+ if (visited.has(task.id))
76
+ continue;
77
+ visited.add(task.id);
78
+ const children = childrenMap[task.id];
79
+ if (children == null || children.length === 0)
80
+ continue;
81
+ // If we've encountered a fork (2+ children), push the fork
82
+ // to the sequence and add a new sequence for each child
83
+ let fork;
84
+ if (children.length > 1) {
85
+ fork = { type: "fork", items: [] };
86
+ task.sequence.items.push(fork);
87
+ }
88
+ for (const value of children) {
89
+ const id = value.checkpoint.checkpoint_id;
90
+ let sequence = task.sequence;
91
+ let path = task.path;
92
+ if (fork != null) {
93
+ sequence = { type: "sequence", items: [] };
94
+ fork.items.unshift(sequence);
95
+ path = path.slice();
96
+ path.push(id);
97
+ paths.push(path);
98
+ }
99
+ sequence.items.push({ type: "node", value, path });
100
+ queue.push({ id, sequence, path });
101
+ }
102
+ }
103
+ return { rootSequence, paths };
104
+ }
105
+ const PATH_SEP = ">";
106
+ const ROOT_ID = "$";
107
+ // Get flat view
108
+ function getBranchView(sequence, paths, branch) {
109
+ const path = branch.split(PATH_SEP);
110
+ const pathMap = {};
111
+ for (const path of paths) {
112
+ const parent = path.at(-2) ?? ROOT_ID;
113
+ pathMap[parent] ??= [];
114
+ pathMap[parent].unshift(path);
115
+ }
116
+ const history = [];
117
+ const branchByCheckpoint = {};
118
+ const forkStack = path.slice();
119
+ const queue = [...sequence.items];
120
+ while (queue.length > 0) {
121
+ const item = queue.shift();
122
+ if (item.type === "node") {
123
+ history.push(item.value);
124
+ branchByCheckpoint[item.value.checkpoint.checkpoint_id] = {
125
+ branch: item.path.join(PATH_SEP),
126
+ branchOptions: (item.path.length > 0
127
+ ? pathMap[item.path.at(-2) ?? ROOT_ID] ?? []
128
+ : []).map((p) => p.join(PATH_SEP)),
129
+ };
130
+ }
131
+ if (item.type === "fork") {
132
+ const forkId = forkStack.shift();
133
+ const index = forkId != null
134
+ ? item.items.findIndex((value) => {
135
+ const firstItem = value.items.at(0);
136
+ if (!firstItem || firstItem.type !== "node")
137
+ return false;
138
+ return firstItem.value.checkpoint.checkpoint_id === forkId;
139
+ })
140
+ : -1;
141
+ const nextItems = item.items.at(index)?.items ?? [];
142
+ queue.push(...nextItems);
143
+ }
144
+ }
145
+ return { history, branchByCheckpoint };
146
+ }
61
147
  function fetchHistory(client, threadId) {
62
148
  return client.threads.getHistory(threadId, { limit: 1000 });
63
149
  }
@@ -84,12 +170,26 @@ function useThreadHistory(threadId, client, clearCallbackRef, submittingRef) {
84
170
  mutate: (mutateId) => fetcher(mutateId ?? threadId),
85
171
  };
86
172
  }
173
+ const useControllableThreadId = (options) => {
174
+ const [localThreadId, _setLocalThreadId] = (0, react_1.useState)(options?.threadId ?? null);
175
+ const onThreadIdRef = (0, react_1.useRef)(options?.onThreadId);
176
+ onThreadIdRef.current = options?.onThreadId;
177
+ const onThreadId = (0, react_1.useCallback)((threadId) => {
178
+ _setLocalThreadId(threadId);
179
+ onThreadIdRef.current?.(threadId);
180
+ }, []);
181
+ if (typeof options?.threadId === "undefined") {
182
+ return [localThreadId, onThreadId];
183
+ }
184
+ return [options.threadId, onThreadId];
185
+ };
87
186
  function useStream(options) {
88
- const { assistantId, threadId, withMessages, onError, onFinish } = options;
187
+ let { assistantId, messagesKey, onError, onFinish } = options;
188
+ messagesKey ??= "messages";
89
189
  const client = (0, react_1.useMemo)(() => new client_js_1.Client({ apiUrl: options.apiUrl, apiKey: options.apiKey }), [options.apiKey, options.apiUrl]);
90
- const [branchPath, setBranchPath] = (0, react_1.useState)([]);
190
+ const [threadId, onThreadId] = useControllableThreadId(options);
191
+ const [branch, setBranch] = (0, react_1.useState)("");
91
192
  const [isLoading, setIsLoading] = (0, react_1.useState)(false);
92
- const [_, setEvents] = (0, react_1.useState)([]);
93
193
  const [streamError, setStreamError] = (0, react_1.useState)(undefined);
94
194
  const [streamValues, setStreamValues] = (0, react_1.useState)(null);
95
195
  const messageManagerRef = (0, react_1.useRef)(new MessageTupleManager());
@@ -119,94 +219,13 @@ function useStream(options) {
119
219
  // TODO: should we permit adapter? SWR / React Query?
120
220
  const history = useThreadHistory(threadId, client, clearCallbackRef, submittingRef);
121
221
  const getMessages = (0, react_1.useMemo)(() => {
122
- if (withMessages == null)
123
- return undefined;
124
- return (value) => Array.isArray(value[withMessages])
125
- ? value[withMessages]
222
+ return (value) => Array.isArray(value[messagesKey])
223
+ ? value[messagesKey]
126
224
  : [];
127
- }, [withMessages]);
128
- const [sequence, pathMap] = (() => {
129
- const childrenMap = {};
130
- // First pass - collect nodes for each checkpoint
131
- history.data.forEach((state) => {
132
- const checkpointId = state.parent_checkpoint?.checkpoint_id ?? "$";
133
- childrenMap[checkpointId] ??= [];
134
- childrenMap[checkpointId].push(state);
135
- });
136
- const rootSequence = { type: "sequence", items: [] };
137
- const queue = [{ id: "$", sequence: rootSequence, path: [] }];
138
- const paths = [];
139
- const visited = new Set();
140
- while (queue.length > 0) {
141
- const task = queue.shift();
142
- if (visited.has(task.id))
143
- continue;
144
- visited.add(task.id);
145
- const children = childrenMap[task.id];
146
- if (children == null || children.length === 0)
147
- continue;
148
- // If we've encountered a fork (2+ children), push the fork
149
- // to the sequence and add a new sequence for each child
150
- let fork;
151
- if (children.length > 1) {
152
- fork = { type: "fork", items: [] };
153
- task.sequence.items.push(fork);
154
- }
155
- for (const value of children) {
156
- const id = value.checkpoint.checkpoint_id;
157
- let sequence = task.sequence;
158
- let path = task.path;
159
- if (fork != null) {
160
- sequence = { type: "sequence", items: [] };
161
- fork.items.unshift(sequence);
162
- path = path.slice();
163
- path.push(id);
164
- paths.push(path);
165
- }
166
- sequence.items.push({ type: "node", value, path });
167
- queue.push({ id, sequence, path });
168
- }
169
- }
170
- // Third pass, create a map for available forks
171
- const pathMap = {};
172
- for (const path of paths) {
173
- const parent = path.at(-2) ?? "$";
174
- pathMap[parent] ??= [];
175
- pathMap[parent].unshift(path);
176
- }
177
- return [rootSequence, pathMap];
178
- })();
179
- const [flatValues, flatPaths] = (() => {
180
- const result = [];
181
- const flatPaths = {};
182
- const forkStack = branchPath.slice();
183
- const queue = [...sequence.items];
184
- while (queue.length > 0) {
185
- const item = queue.shift();
186
- if (item.type === "node") {
187
- result.push(item.value);
188
- flatPaths[item.value.checkpoint.checkpoint_id] = {
189
- current: item.path,
190
- branches: item.path.length > 0 ? pathMap[item.path.at(-2) ?? "$"] ?? [] : [],
191
- };
192
- }
193
- if (item.type === "fork") {
194
- const forkId = forkStack.shift();
195
- const index = forkId != null
196
- ? item.items.findIndex((value) => {
197
- const firstItem = value.items.at(0);
198
- if (!firstItem || firstItem.type !== "node")
199
- return false;
200
- return firstItem.value.checkpoint.checkpoint_id === forkId;
201
- })
202
- : -1;
203
- const nextItems = item.items.at(index)?.items ?? [];
204
- queue.push(...nextItems);
205
- }
206
- }
207
- return [result, flatPaths];
208
- })();
209
- const threadHead = flatValues.at(-1);
225
+ }, [messagesKey]);
226
+ const { rootSequence, paths } = getBranchSequence(history.data);
227
+ const { history: flatHistory, branchByCheckpoint } = getBranchView(rootSequence, paths, branch);
228
+ const threadHead = flatHistory.at(-1);
210
229
  const historyValues = threadHead?.values ?? {};
211
230
  const historyError = (() => {
212
231
  const error = threadHead?.tasks?.at(-1)?.error;
@@ -225,8 +244,6 @@ function useStream(options) {
225
244
  return error;
226
245
  })();
227
246
  const messageMetadata = (() => {
228
- if (getMessages == null)
229
- return undefined;
230
247
  const alreadyShown = new Set();
231
248
  return getMessages(historyValues).map((message, idx) => {
232
249
  const messageId = message.id ?? idx;
@@ -235,12 +252,12 @@ function useStream(options) {
235
252
  .includes(messageId));
236
253
  const firstSeen = history.data[firstSeenIdx];
237
254
  let branch = firstSeen
238
- ? flatPaths[firstSeen.checkpoint.checkpoint_id]
255
+ ? branchByCheckpoint[firstSeen.checkpoint.checkpoint_id]
239
256
  : undefined;
240
- if (!branch?.current?.length)
257
+ if (!branch?.branch?.length)
241
258
  branch = undefined;
242
259
  // serialize branches
243
- const optionsShown = branch?.branches?.flat(2).join(",");
260
+ const optionsShown = branch?.branchOptions?.flat(2).join(",");
244
261
  if (optionsShown) {
245
262
  if (alreadyShown.has(optionsShown))
246
263
  branch = undefined;
@@ -249,8 +266,8 @@ function useStream(options) {
249
266
  return {
250
267
  messageId: messageId.toString(),
251
268
  firstSeenState: firstSeen,
252
- branch: branch?.current?.join(">"),
253
- branchOptions: branch?.branches?.map((b) => b.join(">")),
269
+ branch: branch?.branch,
270
+ branchOptions: branch?.branchOptions,
254
271
  };
255
272
  });
256
273
  })();
@@ -268,7 +285,7 @@ function useStream(options) {
268
285
  let usableThreadId = threadId;
269
286
  if (!usableThreadId) {
270
287
  const thread = await client.threads.create();
271
- options?.onThreadId?.(thread.thread_id);
288
+ onThreadId(thread.thread_id);
272
289
  usableThreadId = thread.thread_id;
273
290
  }
274
291
  const streamMode = unique([
@@ -296,10 +313,10 @@ function useStream(options) {
296
313
  }));
297
314
  // Unbranch things
298
315
  const newPath = submitOptions?.checkpoint?.checkpoint_id
299
- ? flatPaths[submitOptions?.checkpoint?.checkpoint_id]?.current
316
+ ? branchByCheckpoint[submitOptions?.checkpoint?.checkpoint_id]?.branch
300
317
  : undefined;
301
318
  if (newPath != null)
302
- setBranchPath(newPath ?? []);
319
+ setBranch(newPath ?? "");
303
320
  // Assumption: we're setting the initial value
304
321
  // Used for instant feedback
305
322
  setStreamValues(() => {
@@ -316,26 +333,19 @@ function useStream(options) {
316
333
  });
317
334
  let streamError;
318
335
  for await (const { event, data } of run) {
319
- setEvents((events) => [...events, { event, data }]);
320
336
  if (event === "error") {
321
337
  streamError = new StreamError(data);
322
338
  break;
323
339
  }
324
- if (event === "updates") {
340
+ if (event === "updates")
325
341
  options.onUpdateEvent?.(data);
326
- }
327
- if (event === "custom") {
342
+ if (event === "custom")
328
343
  options.onCustomEvent?.(data);
329
- }
330
- if (event === "metadata") {
344
+ if (event === "metadata")
331
345
  options.onMetadataEvent?.(data);
332
- }
333
- if (event === "values") {
346
+ if (event === "values")
334
347
  setStreamValues(data);
335
- }
336
348
  if (event === "messages") {
337
- if (!getMessages)
338
- continue;
339
349
  const [serialized] = data;
340
350
  const messageId = messageManagerRef.current.add(serialized);
341
351
  if (!messageId) {
@@ -350,13 +360,12 @@ function useStream(options) {
350
360
  if (!chunk || index == null)
351
361
  return values;
352
362
  messages[index] = toMessageDict(chunk);
353
- return { ...values, [withMessages]: messages };
363
+ return { ...values, [messagesKey]: messages };
354
364
  });
355
365
  }
356
366
  }
357
367
  // TODO: stream created checkpoints to avoid an unnecessary network request
358
368
  const result = await history.mutate(usableThreadId);
359
- // TODO: write tests verifying that stream values are properly handled lifecycle-wise
360
369
  setStreamValues(null);
361
370
  if (streamError != null)
362
371
  throw streamError;
@@ -381,7 +390,6 @@ function useStream(options) {
381
390
  };
382
391
  const error = isLoading ? streamError : historyError;
383
392
  const values = streamValues ?? historyValues;
384
- const setBranch = (0, react_1.useCallback)((path) => setBranchPath(path.split(">")), [setBranchPath]);
385
393
  return {
386
394
  get values() {
387
395
  trackStreamMode("values");
@@ -391,19 +399,16 @@ function useStream(options) {
391
399
  isLoading,
392
400
  stop,
393
401
  submit,
402
+ branch,
394
403
  setBranch,
404
+ history: flatHistory,
405
+ experimental_branchTree: rootSequence,
395
406
  get messages() {
396
407
  trackStreamMode("messages-tuple");
397
- if (getMessages == null) {
398
- throw new Error("No messages key provided. Make sure that `useStream` contains the `messagesKey` property.");
399
- }
400
408
  return getMessages(values);
401
409
  },
402
410
  getMessagesMetadata(message, index) {
403
411
  trackStreamMode("messages-tuple");
404
- if (getMessages == null) {
405
- throw new Error("No messages key provided. Make sure that `useStream` contains the `messagesKey` property.");
406
- }
407
412
  return messageMetadata?.find((m) => m.messageId === (message.id ?? index));
408
413
  },
409
414
  };
@@ -3,44 +3,153 @@ import type { Command, DisconnectMode, MultitaskStrategy, OnCompletionBehavior }
3
3
  import type { Message } from "../types.messages.js";
4
4
  import type { Checkpoint, Config, Metadata, ThreadState } from "../schema.js";
5
5
  import type { CustomStreamEvent, MetadataStreamEvent, StreamMode, UpdatesStreamEvent } from "../types.stream.js";
6
/** A single thread state placed in the branch tree. */
interface Node<StateType = any> {
    type: "node";
    /** The thread state at this checkpoint. */
    value: ThreadState<StateType>;
    /** Checkpoint IDs chosen at each fork on the way to this node; empty on the unforked trunk. */
    path: string[];
}
/** A point where the history diverges: one sequence per child checkpoint. */
interface Fork<StateType = any> {
    type: "fork";
    items: Sequence<StateType>[];
}
/** A linear run of nodes; when a fork is present it is, in practice, the last item. */
interface Sequence<StateType = any> {
    type: "sequence";
    items: (Node<StateType> | Fork<StateType>)[];
}
6
19
  export type MessageMetadata<StateType extends Record<string, unknown>> = {
20
+ /**
21
+ * The ID of the message used.
22
+ */
7
23
  messageId: string;
24
+ /**
25
+ * The first thread state the message was seen in.
26
+ */
8
27
  firstSeenState: ThreadState<StateType> | undefined;
28
+ /**
29
+ * The branch of the message.
30
+ */
9
31
  branch: string | undefined;
32
+ /**
33
+ * The list of branches this message is part of.
34
+ * This is useful for displaying branching controls.
35
+ */
10
36
  branchOptions: string[] | undefined;
11
37
  };
12
- export declare function useStream<StateType extends Record<string, unknown> = Record<string, unknown>, UpdateType extends Record<string, unknown> = Partial<StateType>, CustomType = unknown>(options: {
38
/** Options accepted by `useStream`. */
interface UseStreamOptions<StateType extends Record<string, unknown> = Record<string, unknown>, UpdateType extends Record<string, unknown> = Partial<StateType>, CustomType = unknown> {
    /** ID of the assistant that runs are created for. */
    assistantId: string;
    /** Base URL of the LangGraph API; used to construct the SDK client. */
    apiUrl: ClientConfig["apiUrl"];
    /** Optional API key; used to construct the SDK client. */
    apiKey?: ClientConfig["apiKey"];
    /**
     * Thread to load history and current values from. Any value other than
     * `undefined` (including `null`) makes the thread ID controlled by the
     * caller; when omitted, the hook keeps track of the thread ID itself.
     */
    threadId?: string | null;
    /**
     * Called with the new thread ID when the hook switches threads, for example
     * when `submit` has to create a thread first.
     */
    onThreadId?: (threadId: string) => void;
    /**
     * Key within the thread state that holds the list of messages.
     *
     * @default "messages"
     */
    messagesKey?: string;
    /** Called when an error occurs. */
    onError?: (error: unknown) => void;
    /** Called when the stream has finished. */
    onFinish?: (state: ThreadState<StateType>) => void;
    /** Called for every `updates` event received while streaming. */
    onUpdateEvent?: (data: UpdatesStreamEvent<UpdateType>["data"]) => void;
    /** Called for every `custom` event received while streaming. */
    onCustomEvent?: (data: CustomStreamEvent<CustomType>["data"]) => void;
    /** Called for every `metadata` event received while streaming. */
    onMetadataEvent?: (data: MetadataStreamEvent["data"]) => void;
}
87
/** Value returned by `useStream`. */
interface UseStream<StateType extends Record<string, unknown> = Record<string, unknown>, UpdateType extends Record<string, unknown> = Partial<StateType>> {
    /**
     * Current values of the thread. While a run streams, these are the streamed
     * values; otherwise they are the values of the head of the selected branch.
     * Exposed as a getter, hence `readonly`.
     */
    readonly values: StateType;
    /**
     * Last seen error: the streaming error while a run is loading, otherwise the
     * error recorded on the last task of the selected branch's head.
     */
    error: unknown;
    /**
     * Whether a run is currently being streamed.
     */
    isLoading: boolean;
    /**
     * Stops the stream.
     */
    stop: () => void;
    /**
     * Create and stream a run on the thread, creating a new thread first when
     * none is selected. `values` may be `undefined`, e.g. when only a `command`
     * is sent. The returned promise resolves after the stream has ended and the
     * thread history has been refetched.
     */
    submit: (values: UpdateType | undefined, options?: SubmitOptions<StateType>) => Promise<void>;
    /**
     * The selected branch: checkpoint IDs chosen at each fork joined with `>`.
     * The empty string selects the default child at every fork.
     */
    branch: string;
    /**
     * Select the branch to display, in the same format as `branch`.
     */
    setBranch: (branch: string) => void;
    /**
     * Thread states along the selected branch, in tree order; the last element
     * is the current head.
     */
    history: ThreadState<StateType>[];
    /**
     * Tree of all branches of the thread.
     * @experimental
     */
    experimental_branchTree: Sequence<StateType>;
    /**
     * Messages read from `values` under the configured messages key.
     * Updates automatically as message chunks stream in. Exposed as a getter,
     * hence `readonly`.
     */
    readonly messages: Message[];
    /**
     * Get the metadata for a message, such as the first thread state the
     * message was seen in and its branch information.
     *
     * @param message - The message to get the metadata for.
     * @param index - Index of the message in the thread; used as the lookup key
     *   when the message has no `id`.
     * @returns The metadata for the message, or `undefined` if none was found.
     */
    getMessagesMetadata: (message: Message, index?: number) => MessageMetadata<StateType> | undefined;
}
140
/** Options accepted by `submit`. */
interface SubmitOptions<StateType extends Record<string, unknown> = Record<string, unknown>> {
    config?: Config;
    /**
     * Checkpoint to run from. When it belongs to the branch currently displayed,
     * the hook first switches to that checkpoint's branch.
     */
    checkpoint?: Omit<Checkpoint, "thread_id"> | null;
    command?: Command;
    interruptBefore?: "*" | string[];
    interruptAfter?: "*" | string[];
    metadata?: Metadata;
    multitaskStrategy?: MultitaskStrategy;
    onCompletion?: OnCompletionBehavior;
    onDisconnect?: DisconnectMode;
    feedbackKeys?: string[];
    streamMode?: StreamMode[];
    /**
     * Values to show right away for instant feedback, either as a partial state
     * or as a function of the current values. Presumably applied before the run
     * streams anything back; confirm against `submit`.
     */
    optimisticValues?: Partial<StateType> | ((prev: StateType) => Partial<StateType>);
}
154
/**
 * React hook that loads a thread's history from the LangGraph API and streams
 * new runs into it, exposing values, messages and branch controls.
 */
export declare function useStream<
    StateType extends Record<string, unknown> = Record<string, unknown>,
    UpdateType extends Record<string, unknown> = Partial<StateType>,
    CustomType = unknown
>(options: UseStreamOptions<StateType, UpdateType, CustomType>): UseStream<StateType, UpdateType>;
export {};