ai-sdk-graph 0.5.0 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,16 +9,20 @@ export declare class CompiledGraph<State extends Record<string, unknown>, NodeKe
9
9
  private readonly edgeRegistry;
10
10
  private readonly subgraphRegistry;
11
11
  private readonly storage;
12
- private readonly emitter;
13
12
  private readonly stateManager;
14
- private readonly onFinish;
15
- private readonly onStart;
13
+ private readonly graphMiddleware;
14
+ private readonly nodeMiddleware;
15
+ private readonly stateMiddleware;
16
+ private readonly eventMiddleware;
16
17
  constructor(nodeRegistry: ReadonlyMap<NodeKeys, GraphSDK.Node<State, NodeKeys>>, edgeRegistry: ReadonlyMap<NodeKeys, GraphSDK.Edge<State, NodeKeys>[]>, subgraphRegistry: ReadonlyMap<NodeKeys, {
17
18
  subgraph: Graph<any, any>;
18
19
  options: GraphSDK.SubgraphOptions<State, any>;
19
- }>, options?: GraphSDK.CompileOptions<State, NodeKeys>);
20
- execute(runId: string, initialState: State | ((state: State | undefined) => State)): ReadableStream<import("ai").InferUIMessageChunk<import("ai").UIMessage<unknown, import("ai").UIDataTypes, import("ai").UITools>>>;
21
- executeInternal(runId: string, initialState: State, writer: GraphSDK.Writer): Promise<State>;
20
+ }>, options?: GraphSDK.CompileOptions<State, NodeKeys>, graphMiddleware?: GraphSDK.GraphMiddleware<State, NodeKeys>[], nodeMiddleware?: GraphSDK.NodeMiddleware<State, NodeKeys>[], stateMiddleware?: GraphSDK.StateMiddleware<State, NodeKeys>[], eventMiddleware?: GraphSDK.EventMiddleware<State, NodeKeys>[]);
21
+ stream(runId: string, initialState: State | ((state: State | undefined) => State)): ReadableStream<import("ai").InferUIMessageChunk<import("ai").UIMessage<unknown, import("ai").UIDataTypes, import("ai").UITools>>>;
22
+ execute(runId: string, initialState: State | ((state: State | undefined) => State), options?: {
23
+ onEvent?: (event: GraphSDK.GraphEvent<State, NodeKeys>) => void;
24
+ }): Promise<State>;
25
+ executeInternal(runId: string, initialState: State, writer: GraphSDK.Writer, emit: (event: GraphSDK.GraphEvent<State, NodeKeys>) => void): Promise<State>;
22
26
  private createExecutionContext;
23
27
  private runExecutionLoop;
24
28
  private runExecutionLoopInternal;
package/dist/graph.d.ts CHANGED
@@ -5,12 +5,13 @@ export declare class Graph<State extends Record<string, unknown>, NodeKeys exten
5
5
  private readonly nodeRegistry;
6
6
  private readonly edgeRegistry;
7
7
  private readonly subgraphRegistry;
8
+ private readonly middlewares;
8
9
  constructor();
9
10
  node<NewKey extends string>(id: NewKey, execute: ({ state, writer, suspense, update }: {
10
11
  state: () => Readonly<State>;
11
12
  writer: GraphSDK.Writer;
12
13
  suspense: (data?: unknown) => never;
13
- update: (update: GraphSDK.StateUpdate<State>) => void;
14
+ update: (update: GraphSDK.StateUpdate<State>) => Promise<void>;
14
15
  }) => Promise<void> | void): Graph<State, NodeKeys | NewKey>;
15
16
  edge(from: NodeKeys, to: NodeKeys | ((state: State) => NodeKeys)): Graph<State, NodeKeys>;
16
17
  graph<NewKey extends string, ChildState extends Record<string, unknown>, ChildNodeKeys extends string = 'START' | 'END'>(id: NewKey, subgraph: Graph<ChildState, ChildNodeKeys>, options: GraphSDK.SubgraphOptions<State, ChildState>): Graph<State, NodeKeys | NewKey>;
@@ -20,6 +21,7 @@ export declare class Graph<State extends Record<string, unknown>, NodeKeys exten
20
21
  subgraph: Graph<any, any>;
21
22
  options: GraphSDK.SubgraphOptions<State, any>;
22
23
  }>;
24
+ use(middleware: GraphSDK.Middleware<State, NodeKeys>): Graph<State, NodeKeys>;
23
25
  compile(options?: GraphSDK.CompileOptions<State, NodeKeys>): CompiledGraph<State, NodeKeys>;
24
26
  toMermaid(options?: {
25
27
  direction?: 'TB' | 'LR';
package/dist/index.d.ts CHANGED
@@ -3,3 +3,4 @@ export * from './compiled-graph';
3
3
  export * from './storage';
4
4
  export * from './types';
5
5
  export * from './utils';
6
+ export * from './middleware';
package/dist/index.js CHANGED
@@ -33,6 +33,21 @@ class RedisStorage {
33
33
 
34
34
  // src/compiled-graph.ts
35
35
  import { createUIMessageStream } from "ai";
36
+
37
+ // src/middleware.ts
38
+ function composeMiddleware(middlewares, action) {
39
+ return (ctx) => {
40
+ function dispatch(i) {
41
+ if (i === middlewares.length) {
42
+ return action(ctx);
43
+ }
44
+ return middlewares[i](ctx, () => dispatch(i + 1));
45
+ }
46
+ return dispatch(0);
47
+ };
48
+ }
49
+
50
+ // src/compiled-graph.ts
36
51
  var BUILT_IN_NODES = {
37
52
  START: "START",
38
53
  END: "END"
@@ -46,51 +61,108 @@ class SuspenseError extends Error {
46
61
  this.data = data;
47
62
  }
48
63
  }
64
+ var nullWriter = {
65
+ write() {},
66
+ merge() {}
67
+ };
68
+ function composeEventMiddleware(middlewares, terminal) {
69
+ return (event) => {
70
+ function dispatch(i) {
71
+ if (i === middlewares.length) {
72
+ terminal(event);
73
+ return;
74
+ }
75
+ middlewares[i](event, () => dispatch(i + 1));
76
+ }
77
+ dispatch(0);
78
+ };
79
+ }
49
80
 
50
81
  class CompiledGraph {
51
82
  nodeRegistry;
52
83
  edgeRegistry;
53
84
  subgraphRegistry;
54
85
  storage;
55
- emitter = new NodeEventEmitter;
56
86
  stateManager = new StateManager;
57
- onFinish;
58
- onStart;
59
- constructor(nodeRegistry, edgeRegistry, subgraphRegistry, options = {}) {
87
+ graphMiddleware;
88
+ nodeMiddleware;
89
+ stateMiddleware;
90
+ eventMiddleware;
91
+ constructor(nodeRegistry, edgeRegistry, subgraphRegistry, options = {}, graphMiddleware = [], nodeMiddleware = [], stateMiddleware = [], eventMiddleware = []) {
60
92
  this.nodeRegistry = nodeRegistry;
61
93
  this.edgeRegistry = edgeRegistry;
62
94
  this.subgraphRegistry = subgraphRegistry;
63
95
  this.storage = options.storage ?? new InMemoryStorage;
64
- this.onFinish = options.onFinish ?? (() => {});
65
- this.onStart = options.onStart ?? (() => {});
96
+ this.graphMiddleware = graphMiddleware;
97
+ this.nodeMiddleware = nodeMiddleware;
98
+ this.stateMiddleware = stateMiddleware;
99
+ this.eventMiddleware = eventMiddleware;
66
100
  }
67
- execute(runId, initialState) {
101
+ stream(runId, initialState) {
68
102
  let context;
69
103
  return createUIMessageStream({
70
104
  execute: async ({ writer }) => {
71
- const result = await this.createExecutionContext(runId, initialState, writer);
105
+ const emit = composeEventMiddleware(this.eventMiddleware, (event) => {
106
+ switch (event.type) {
107
+ case "state":
108
+ writer.write({ type: "data-state", data: event.state });
109
+ break;
110
+ case "node:start":
111
+ writer.write({ type: "data-node-start", data: event.nodeId });
112
+ break;
113
+ case "node:end":
114
+ writer.write({ type: "data-node-end", data: event.nodeId });
115
+ break;
116
+ case "node:suspense":
117
+ writer.write({ type: "data-node-suspense", data: { nodeId: event.nodeId, data: event.data } });
118
+ break;
119
+ }
120
+ });
121
+ const result = await this.createExecutionContext(runId, initialState, emit, writer);
72
122
  context = result.context;
73
123
  const firstTime = result.firstTime;
74
- if (firstTime) {
75
- await this.onStart({ state: context.state, writer });
76
- }
77
- await this.runExecutionLoop(context);
78
- },
79
- onFinish: async () => {
80
- if (context) {
81
- await this.onFinish({ state: context.state });
124
+ if (this.graphMiddleware.length > 0) {
125
+ const graphCtx = {
126
+ runId: context.runId,
127
+ state: () => context.state,
128
+ writer: context.writer,
129
+ isResume: !firstTime
130
+ };
131
+ await composeMiddleware(this.graphMiddleware, async () => {
132
+ await this.runExecutionLoop(context);
133
+ })(graphCtx);
134
+ } else {
135
+ await this.runExecutionLoop(context);
82
136
  }
83
137
  }
84
138
  });
85
139
  }
86
- async executeInternal(runId, initialState, writer) {
87
- const { context } = await this.createExecutionContext(runId, initialState, writer);
140
+ async execute(runId, initialState, options) {
141
+ const emit = composeEventMiddleware(this.eventMiddleware, options?.onEvent ?? (() => {}));
142
+ const { context, firstTime } = await this.createExecutionContext(runId, initialState, emit, nullWriter);
143
+ if (this.graphMiddleware.length > 0) {
144
+ const graphCtx = {
145
+ runId: context.runId,
146
+ state: () => context.state,
147
+ writer: context.writer,
148
+ isResume: !firstTime
149
+ };
150
+ await composeMiddleware(this.graphMiddleware, async () => {
151
+ await this.runExecutionLoop(context);
152
+ })(graphCtx);
153
+ } else {
154
+ await this.runExecutionLoop(context);
155
+ }
156
+ return context.state;
157
+ }
158
+ async executeInternal(runId, initialState, writer, emit) {
159
+ const { context } = await this.createExecutionContext(runId, initialState, emit, writer);
88
160
  await this.runExecutionLoopInternal(context);
89
161
  return context.state;
90
162
  }
91
- async createExecutionContext(runId, initialState, writer) {
92
- const { context, firstTime } = await this.restoreCheckpoint(runId, initialState, writer);
93
- return { context: { ...context, runId, writer }, firstTime };
163
+ async createExecutionContext(runId, initialState, emit, writer) {
164
+ const { context, firstTime } = await this.restoreCheckpoint(runId, initialState, emit);
165
+ return { context: { ...context, runId, writer, emit }, firstTime };
94
166
  }
95
167
  async runExecutionLoop(context) {
96
168
  await this.executeWithStrategy(context, "return");
@@ -139,12 +211,12 @@ class CompiledGraph {
139
211
  hasNodesToExecute(context) {
140
212
  return context.currentNodes.length > 0;
141
213
  }
142
- async restoreCheckpoint(runId, initialState, writer) {
214
+ async restoreCheckpoint(runId, initialState, emit) {
143
215
  const checkpoint = await this.storage.load(runId);
144
216
  if (this.isValidCheckpoint(checkpoint)) {
145
- return { context: this.restoreFromCheckpoint(checkpoint, initialState, writer), firstTime: false };
217
+ return { context: this.restoreFromCheckpoint(checkpoint, initialState, emit), firstTime: false };
146
218
  }
147
- return { context: this.createFreshExecution(initialState, writer), firstTime: true };
219
+ return { context: this.createFreshExecution(initialState, emit), firstTime: true };
148
220
  }
149
221
  isValidCheckpoint(checkpoint) {
150
222
  return this.hasNodeIds(checkpoint) && this.hasAtLeastOneNode(checkpoint);
@@ -155,16 +227,16 @@ class CompiledGraph {
155
227
  hasAtLeastOneNode(checkpoint) {
156
228
  return checkpoint.nodeIds.length > 0;
157
229
  }
158
- restoreFromCheckpoint(checkpoint, initialState, writer) {
159
- const state = this.stateManager.resolve(initialState, checkpoint.state, writer);
230
+ restoreFromCheckpoint(checkpoint, initialState, emit) {
231
+ const state = this.stateManager.resolve(initialState, checkpoint.state, emit);
160
232
  return {
161
233
  state,
162
234
  currentNodes: this.resolveNodeIds(checkpoint.nodeIds),
163
235
  suspendedNodes: this.resolveNodeIds(checkpoint.suspendedNodes)
164
236
  };
165
237
  }
166
- createFreshExecution(initialState, writer) {
167
- const state = this.stateManager.resolve(initialState, undefined, writer);
238
+ createFreshExecution(initialState, emit) {
239
+ const state = this.stateManager.resolve(initialState, undefined, emit);
168
240
  const startNode = this.nodeRegistry.get(BUILT_IN_NODES.START);
169
241
  return {
170
242
  state,
@@ -188,43 +260,110 @@ class CompiledGraph {
188
260
  if (subgraphEntry) {
189
261
  return this.executeSubgraphNode(node, context, subgraphEntry);
190
262
  }
191
- this.emitter.emitStart(context.writer, node.id);
263
+ const isBuiltIn = node.id === BUILT_IN_NODES.START || node.id === BUILT_IN_NODES.END;
264
+ context.emit({ type: "node:start", nodeId: node.id });
192
265
  try {
193
- await node.execute(this.createNodeExecutionParams(context));
194
- this.emitter.emitEnd(context.writer, node.id);
266
+ const { params, pendingUpdates } = this.createNodeExecutionParams(context, node.id);
267
+ if (!isBuiltIn && this.nodeMiddleware.length > 0) {
268
+ const nodeCtx = {
269
+ runId: context.runId,
270
+ nodeId: node.id,
271
+ state: () => context.state,
272
+ writer: context.writer,
273
+ isSubgraph: false
274
+ };
275
+ await composeMiddleware(this.nodeMiddleware, async () => {
276
+ await node.execute(params);
277
+ })(nodeCtx);
278
+ } else {
279
+ await node.execute(params);
280
+ }
281
+ await Promise.all(pendingUpdates);
282
+ context.emit({ type: "node:end", nodeId: node.id });
195
283
  return null;
196
284
  } catch (error) {
197
285
  if (error instanceof SuspenseError) {
198
- this.emitter.emitSuspense(context.writer, node.id, error.data);
286
+ context.emit({ type: "node:suspense", nodeId: node.id, data: error.data });
199
287
  return { node, error };
200
288
  }
201
289
  throw error;
202
290
  }
203
291
  }
204
- createNodeExecutionParams(context) {
205
- return {
292
+ createNodeExecutionParams(context, nodeId = null) {
293
+ const pendingUpdates = [];
294
+ const params = {
206
295
  state: () => context.state,
207
296
  writer: context.writer,
208
297
  suspense: this.createSuspenseFunction(),
209
298
  update: (update) => {
210
- context.state = this.stateManager.apply(context.state, update, context.writer);
299
+ const p = (async () => {
300
+ if (this.stateMiddleware.length > 0) {
301
+ const resolvedUpdate = typeof update === "function" ? update(context.state) : update;
302
+ const stateCtx = {
303
+ runId: context.runId,
304
+ nodeId,
305
+ currentState: context.state,
306
+ update,
307
+ resolvedUpdate
308
+ };
309
+ const finalPartial = await composeMiddleware(this.stateMiddleware, async (ctx) => ctx.resolvedUpdate)(stateCtx);
310
+ context.state = { ...context.state, ...finalPartial };
311
+ context.emit({ type: "state", state: context.state });
312
+ } else {
313
+ context.state = this.stateManager.apply(context.state, update);
314
+ context.emit({ type: "state", state: context.state });
315
+ }
316
+ })();
317
+ pendingUpdates.push(p);
318
+ return p;
211
319
  }
212
320
  };
321
+ return { params, pendingUpdates };
213
322
  }
214
323
  async executeSubgraphNode(node, context, entry) {
215
324
  const { subgraph, options } = entry;
216
- this.emitter.emitStart(context.writer, node.id);
325
+ context.emit({ type: "node:start", nodeId: node.id });
217
326
  const subgraphRunId = this.generateSubgraphRunId(context.runId, node.id);
218
327
  try {
219
- const childRunner = new CompiledGraph(subgraph.nodes, subgraph.edges, subgraph.subgraphs, { storage: this.storage });
220
- const childFinalState = await childRunner.executeInternal(subgraphRunId, options.input(context.state), context.writer);
221
- const parentUpdate = options.output(childFinalState, context.state);
222
- context.state = this.stateManager.apply(context.state, parentUpdate, context.writer);
223
- this.emitter.emitEnd(context.writer, node.id);
328
+ const executeSubgraph = async () => {
329
+ const childRunner = new CompiledGraph(subgraph.nodes, subgraph.edges, subgraph.subgraphs, { storage: this.storage }, [], this.nodeMiddleware, this.stateMiddleware, this.eventMiddleware);
330
+ const childFinalState = await childRunner.executeInternal(subgraphRunId, options.input(context.state), context.writer, context.emit);
331
+ const parentUpdate = options.output(childFinalState, context.state);
332
+ if (this.stateMiddleware.length > 0) {
333
+ const resolvedUpdate = typeof parentUpdate === "function" ? parentUpdate(context.state) : parentUpdate;
334
+ const stateCtx = {
335
+ runId: context.runId,
336
+ nodeId: node.id,
337
+ currentState: context.state,
338
+ update: parentUpdate,
339
+ resolvedUpdate
340
+ };
341
+ const finalPartial = await composeMiddleware(this.stateMiddleware, async (ctx) => ctx.resolvedUpdate)(stateCtx);
342
+ context.state = { ...context.state, ...finalPartial };
343
+ } else {
344
+ context.state = this.stateManager.apply(context.state, parentUpdate);
345
+ }
346
+ context.emit({ type: "state", state: context.state });
347
+ };
348
+ if (this.nodeMiddleware.length > 0) {
349
+ const nodeCtx = {
350
+ runId: context.runId,
351
+ nodeId: node.id,
352
+ state: () => context.state,
353
+ writer: context.writer,
354
+ isSubgraph: true
355
+ };
356
+ await composeMiddleware(this.nodeMiddleware, async () => {
357
+ await executeSubgraph();
358
+ })(nodeCtx);
359
+ } else {
360
+ await executeSubgraph();
361
+ }
362
+ context.emit({ type: "node:end", nodeId: node.id });
224
363
  return null;
225
364
  } catch (error) {
226
365
  if (error instanceof SuspenseError) {
227
- this.emitter.emitSuspense(context.writer, node.id, error.data);
366
+ context.emit({ type: "node:suspense", nodeId: node.id, data: error.data });
228
367
  return { node, error };
229
368
  }
230
369
  throw error;
@@ -262,35 +401,21 @@ class CompiledGraph {
262
401
  }
263
402
  }
264
403
 
265
- class NodeEventEmitter {
266
- emitStart(writer, nodeId) {
267
- writer.write({ type: "data-node-start", data: nodeId });
268
- }
269
- emitEnd(writer, nodeId) {
270
- writer.write({ type: "data-node-end", data: nodeId });
271
- }
272
- emitSuspense(writer, nodeId, data) {
273
- writer.write({ type: "data-node-suspense", data: { nodeId, data } });
274
- }
275
- }
276
-
277
404
  class StateManager {
278
- apply(state, update, writer) {
279
- const newState = {
405
+ apply(state, update) {
406
+ return {
280
407
  ...state,
281
408
  ...typeof update === "function" ? update(state) : update
282
409
  };
283
- writer.write({ type: "data-state", data: newState });
284
- return newState;
285
410
  }
286
- resolve(initialState, existingState, writer) {
411
+ resolve(initialState, existingState, emit) {
287
412
  if (this.isStateFactory(initialState)) {
288
413
  const newState2 = initialState(existingState);
289
- writer.write({ type: "data-state", data: newState2 });
414
+ emit({ type: "state", state: newState2 });
290
415
  return newState2;
291
416
  }
292
417
  const newState = existingState ?? initialState;
293
- writer.write({ type: "data-state", data: newState });
418
+ emit({ type: "state", state: newState });
294
419
  return newState;
295
420
  }
296
421
  isStateFactory(initialState) {
@@ -303,6 +428,7 @@ class Graph {
303
428
  nodeRegistry = new Map;
304
429
  edgeRegistry = new Map;
305
430
  subgraphRegistry = new Map;
431
+ middlewares = [];
306
432
  constructor() {
307
433
  this.registerBuiltInNodes();
308
434
  }
@@ -334,8 +460,26 @@ class Graph {
334
460
  get subgraphs() {
335
461
  return this.subgraphRegistry;
336
462
  }
463
+ use(middleware) {
464
+ this.middlewares.push(middleware);
465
+ return this;
466
+ }
337
467
  compile(options = {}) {
338
- return new CompiledGraph(this.nodeRegistry, this.edgeRegistry, this.subgraphRegistry, options);
468
+ const graphMiddleware = [];
469
+ const nodeMiddleware = [];
470
+ const stateMiddleware = [];
471
+ const eventMiddleware = [];
472
+ for (const mw of this.middlewares) {
473
+ if (mw.graph)
474
+ graphMiddleware.push(mw.graph);
475
+ if (mw.node)
476
+ nodeMiddleware.push(mw.node);
477
+ if (mw.state)
478
+ stateMiddleware.push(mw.state);
479
+ if (mw.event)
480
+ eventMiddleware.push(mw.event);
481
+ }
482
+ return new CompiledGraph(this.nodeRegistry, this.edgeRegistry, this.subgraphRegistry, options, graphMiddleware, nodeMiddleware, stateMiddleware, eventMiddleware);
339
483
  }
340
484
  toMermaid(options) {
341
485
  const generator = new MermaidGenerator(this.nodeRegistry, this.edgeRegistry, this.subgraphRegistry);
@@ -468,6 +612,7 @@ export {
468
612
  isGraphDataPart,
469
613
  graph,
470
614
  consumeAndMergeStream,
615
+ composeMiddleware,
471
616
  SuspenseError,
472
617
  RedisStorage,
473
618
  InMemoryStorage,
@@ -0,0 +1 @@
1
+ export declare function composeMiddleware<Ctx, R = void>(middlewares: Array<(ctx: Ctx, next: () => Promise<R>) => Promise<R>>, action: (ctx: Ctx) => Promise<R>): (ctx: Ctx) => Promise<R>;
package/dist/types.d.ts CHANGED
@@ -3,7 +3,7 @@ export declare namespace GraphSDK {
3
3
  type StateUpdate<State> = Partial<State> | ((state: State) => Partial<State>);
4
4
  interface SubgraphOptions<ParentState extends Record<string, unknown>, ChildState extends Record<string, unknown>> {
5
5
  input: (parentState: ParentState) => ChildState;
6
- output: (childState: ChildState, parentState: ParentState) => Partial<ParentState>;
6
+ output: (childState: ChildState, parentState: ParentState) => StateUpdate<ParentState>;
7
7
  }
8
8
  interface Graph<State extends Record<string, unknown>, NodeKeys extends string> {
9
9
  nodes: Map<NodeKeys, Node<State, NodeKeys>>;
@@ -15,7 +15,7 @@ export declare namespace GraphSDK {
15
15
  state: () => Readonly<State>;
16
16
  writer: UIMessageStreamWriter;
17
17
  suspense: (data?: unknown) => never;
18
- update: (update: StateUpdate<State>) => void;
18
+ update: (update: StateUpdate<State>) => Promise<void>;
19
19
  }) => Promise<void> | void;
20
20
  }
21
21
  interface Edge<State extends Record<string, unknown>, NodeKeys extends string> {
@@ -37,27 +37,58 @@ export declare namespace GraphSDK {
37
37
  state: State;
38
38
  currentNodes: Node<State, NodeKeys>[];
39
39
  suspendedNodes: Node<State, NodeKeys>[];
40
- writer: UIMessageStreamWriter;
40
+ writer: Writer;
41
+ emit: (event: GraphEvent<State, NodeKeys>) => void;
41
42
  }
42
43
  interface GraphOptions<State extends Record<string, unknown>, NodeKeys extends string> {
43
44
  storage?: GraphStorage<State, NodeKeys>;
44
- onFinish?: (args: {
45
- state: State;
46
- }) => Promise<void> | void;
47
- onStart?: (args: {
48
- state: State;
49
- writer: UIMessageStreamWriter;
50
- }) => Promise<void> | void;
51
45
  }
52
46
  interface CompileOptions<State extends Record<string, unknown>, NodeKeys extends string> {
53
47
  storage?: GraphStorage<State, NodeKeys>;
54
- onFinish?: (args: {
55
- state: State;
56
- }) => Promise<void> | void;
57
- onStart?: (args: {
58
- state: State;
59
- writer: UIMessageStreamWriter;
60
- }) => Promise<void> | void;
61
48
  }
62
49
  type Writer = Parameters<Parameters<typeof createUIMessageStream>[0]['execute']>[0]['writer'];
50
+ interface GraphMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
51
+ runId: string;
52
+ state: () => Readonly<State>;
53
+ writer: Writer;
54
+ isResume: boolean;
55
+ }
56
+ interface NodeMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
57
+ runId: string;
58
+ nodeId: NodeKeys;
59
+ state: () => Readonly<State>;
60
+ writer: Writer;
61
+ isSubgraph: boolean;
62
+ }
63
+ interface StateMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
64
+ runId: string;
65
+ nodeId: NodeKeys | null;
66
+ currentState: Readonly<State>;
67
+ update: StateUpdate<State>;
68
+ resolvedUpdate: Partial<State>;
69
+ }
70
+ type GraphMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: GraphMiddlewareContext<State, NodeKeys>, next: () => Promise<void>) => Promise<void>;
71
+ type NodeMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: NodeMiddlewareContext<State, NodeKeys>, next: () => Promise<void>) => Promise<void>;
72
+ type StateMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: StateMiddlewareContext<State, NodeKeys>, next: () => Promise<Partial<State>>) => Promise<Partial<State>>;
73
+ type GraphEvent<State extends Record<string, unknown>, NodeKeys extends string> = {
74
+ type: 'state';
75
+ state: State;
76
+ } | {
77
+ type: 'node:start';
78
+ nodeId: NodeKeys;
79
+ } | {
80
+ type: 'node:end';
81
+ nodeId: NodeKeys;
82
+ } | {
83
+ type: 'node:suspense';
84
+ nodeId: NodeKeys;
85
+ data: unknown;
86
+ };
87
+ type EventMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (event: GraphEvent<State, NodeKeys>, next: () => void) => void;
88
+ interface Middleware<State extends Record<string, unknown>, NodeKeys extends string> {
89
+ graph?: GraphMiddleware<State, NodeKeys>;
90
+ node?: NodeMiddleware<State, NodeKeys>;
91
+ state?: StateMiddleware<State, NodeKeys>;
92
+ event?: EventMiddleware<State, NodeKeys>;
93
+ }
63
94
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "ai-sdk-graph",
3
- "version": "0.5.0",
3
+ "version": "0.6.0",
4
4
  "description": "Graph-based workflows for the AI SDK",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",