ai-sdk-graph 0.5.0 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/compiled-graph.d.ts +10 -6
- package/dist/graph.d.ts +3 -1
- package/dist/index.d.ts +1 -0
- package/dist/index.js +207 -62
- package/dist/middleware.d.ts +1 -0
- package/dist/types.d.ts +48 -17
- package/package.json +1 -1
package/dist/compiled-graph.d.ts
CHANGED
|
@@ -9,16 +9,20 @@ export declare class CompiledGraph<State extends Record<string, unknown>, NodeKe
|
|
|
9
9
|
private readonly edgeRegistry;
|
|
10
10
|
private readonly subgraphRegistry;
|
|
11
11
|
private readonly storage;
|
|
12
|
-
private readonly emitter;
|
|
13
12
|
private readonly stateManager;
|
|
14
|
-
private readonly
|
|
15
|
-
private readonly
|
|
13
|
+
private readonly graphMiddleware;
|
|
14
|
+
private readonly nodeMiddleware;
|
|
15
|
+
private readonly stateMiddleware;
|
|
16
|
+
private readonly eventMiddleware;
|
|
16
17
|
constructor(nodeRegistry: ReadonlyMap<NodeKeys, GraphSDK.Node<State, NodeKeys>>, edgeRegistry: ReadonlyMap<NodeKeys, GraphSDK.Edge<State, NodeKeys>[]>, subgraphRegistry: ReadonlyMap<NodeKeys, {
|
|
17
18
|
subgraph: Graph<any, any>;
|
|
18
19
|
options: GraphSDK.SubgraphOptions<State, any>;
|
|
19
|
-
}>, options?: GraphSDK.CompileOptions<State, NodeKeys>);
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
}>, options?: GraphSDK.CompileOptions<State, NodeKeys>, graphMiddleware?: GraphSDK.GraphMiddleware<State, NodeKeys>[], nodeMiddleware?: GraphSDK.NodeMiddleware<State, NodeKeys>[], stateMiddleware?: GraphSDK.StateMiddleware<State, NodeKeys>[], eventMiddleware?: GraphSDK.EventMiddleware<State, NodeKeys>[]);
|
|
21
|
+
stream(runId: string, initialState: State | ((state: State | undefined) => State)): ReadableStream<import("ai").InferUIMessageChunk<import("ai").UIMessage<unknown, import("ai").UIDataTypes, import("ai").UITools>>>;
|
|
22
|
+
execute(runId: string, initialState: State | ((state: State | undefined) => State), options?: {
|
|
23
|
+
onEvent?: (event: GraphSDK.GraphEvent<State, NodeKeys>) => void;
|
|
24
|
+
}): Promise<State>;
|
|
25
|
+
executeInternal(runId: string, initialState: State, writer: GraphSDK.Writer, emit: (event: GraphSDK.GraphEvent<State, NodeKeys>) => void): Promise<State>;
|
|
22
26
|
private createExecutionContext;
|
|
23
27
|
private runExecutionLoop;
|
|
24
28
|
private runExecutionLoopInternal;
|
package/dist/graph.d.ts
CHANGED
|
@@ -5,12 +5,13 @@ export declare class Graph<State extends Record<string, unknown>, NodeKeys exten
|
|
|
5
5
|
private readonly nodeRegistry;
|
|
6
6
|
private readonly edgeRegistry;
|
|
7
7
|
private readonly subgraphRegistry;
|
|
8
|
+
private readonly middlewares;
|
|
8
9
|
constructor();
|
|
9
10
|
node<NewKey extends string>(id: NewKey, execute: ({ state, writer, suspense, update }: {
|
|
10
11
|
state: () => Readonly<State>;
|
|
11
12
|
writer: GraphSDK.Writer;
|
|
12
13
|
suspense: (data?: unknown) => never;
|
|
13
|
-
update: (update: GraphSDK.StateUpdate<State>) => void
|
|
14
|
+
update: (update: GraphSDK.StateUpdate<State>) => Promise<void>;
|
|
14
15
|
}) => Promise<void> | void): Graph<State, NodeKeys | NewKey>;
|
|
15
16
|
edge(from: NodeKeys, to: NodeKeys | ((state: State) => NodeKeys)): Graph<State, NodeKeys>;
|
|
16
17
|
graph<NewKey extends string, ChildState extends Record<string, unknown>, ChildNodeKeys extends string = 'START' | 'END'>(id: NewKey, subgraph: Graph<ChildState, ChildNodeKeys>, options: GraphSDK.SubgraphOptions<State, ChildState>): Graph<State, NodeKeys | NewKey>;
|
|
@@ -20,6 +21,7 @@ export declare class Graph<State extends Record<string, unknown>, NodeKeys exten
|
|
|
20
21
|
subgraph: Graph<any, any>;
|
|
21
22
|
options: GraphSDK.SubgraphOptions<State, any>;
|
|
22
23
|
}>;
|
|
24
|
+
use(middleware: GraphSDK.Middleware<State, NodeKeys>): Graph<State, NodeKeys>;
|
|
23
25
|
compile(options?: GraphSDK.CompileOptions<State, NodeKeys>): CompiledGraph<State, NodeKeys>;
|
|
24
26
|
toMermaid(options?: {
|
|
25
27
|
direction?: 'TB' | 'LR';
|
package/dist/index.d.ts
CHANGED
package/dist/index.js
CHANGED
|
@@ -33,6 +33,21 @@ class RedisStorage {
|
|
|
33
33
|
|
|
34
34
|
// src/compiled-graph.ts
|
|
35
35
|
import { createUIMessageStream } from "ai";
|
|
36
|
+
|
|
37
|
+
// src/middleware.ts
|
|
38
|
+
function composeMiddleware(middlewares, action) {
|
|
39
|
+
return (ctx) => {
|
|
40
|
+
function dispatch(i) {
|
|
41
|
+
if (i === middlewares.length) {
|
|
42
|
+
return action(ctx);
|
|
43
|
+
}
|
|
44
|
+
return middlewares[i](ctx, () => dispatch(i + 1));
|
|
45
|
+
}
|
|
46
|
+
return dispatch(0);
|
|
47
|
+
};
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
// src/compiled-graph.ts
|
|
36
51
|
var BUILT_IN_NODES = {
|
|
37
52
|
START: "START",
|
|
38
53
|
END: "END"
|
|
@@ -46,51 +61,108 @@ class SuspenseError extends Error {
|
|
|
46
61
|
this.data = data;
|
|
47
62
|
}
|
|
48
63
|
}
|
|
64
|
+
var nullWriter = {
|
|
65
|
+
write() {},
|
|
66
|
+
merge() {}
|
|
67
|
+
};
|
|
68
|
+
function composeEventMiddleware(middlewares, terminal) {
|
|
69
|
+
return (event) => {
|
|
70
|
+
function dispatch(i) {
|
|
71
|
+
if (i === middlewares.length) {
|
|
72
|
+
terminal(event);
|
|
73
|
+
return;
|
|
74
|
+
}
|
|
75
|
+
middlewares[i](event, () => dispatch(i + 1));
|
|
76
|
+
}
|
|
77
|
+
dispatch(0);
|
|
78
|
+
};
|
|
79
|
+
}
|
|
49
80
|
|
|
50
81
|
class CompiledGraph {
|
|
51
82
|
nodeRegistry;
|
|
52
83
|
edgeRegistry;
|
|
53
84
|
subgraphRegistry;
|
|
54
85
|
storage;
|
|
55
|
-
emitter = new NodeEventEmitter;
|
|
56
86
|
stateManager = new StateManager;
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
87
|
+
graphMiddleware;
|
|
88
|
+
nodeMiddleware;
|
|
89
|
+
stateMiddleware;
|
|
90
|
+
eventMiddleware;
|
|
91
|
+
constructor(nodeRegistry, edgeRegistry, subgraphRegistry, options = {}, graphMiddleware = [], nodeMiddleware = [], stateMiddleware = [], eventMiddleware = []) {
|
|
60
92
|
this.nodeRegistry = nodeRegistry;
|
|
61
93
|
this.edgeRegistry = edgeRegistry;
|
|
62
94
|
this.subgraphRegistry = subgraphRegistry;
|
|
63
95
|
this.storage = options.storage ?? new InMemoryStorage;
|
|
64
|
-
this.
|
|
65
|
-
this.
|
|
96
|
+
this.graphMiddleware = graphMiddleware;
|
|
97
|
+
this.nodeMiddleware = nodeMiddleware;
|
|
98
|
+
this.stateMiddleware = stateMiddleware;
|
|
99
|
+
this.eventMiddleware = eventMiddleware;
|
|
66
100
|
}
|
|
67
|
-
|
|
101
|
+
stream(runId, initialState) {
|
|
68
102
|
let context;
|
|
69
103
|
return createUIMessageStream({
|
|
70
104
|
execute: async ({ writer }) => {
|
|
71
|
-
const
|
|
105
|
+
const emit = composeEventMiddleware(this.eventMiddleware, (event) => {
|
|
106
|
+
switch (event.type) {
|
|
107
|
+
case "state":
|
|
108
|
+
writer.write({ type: "data-state", data: event.state });
|
|
109
|
+
break;
|
|
110
|
+
case "node:start":
|
|
111
|
+
writer.write({ type: "data-node-start", data: event.nodeId });
|
|
112
|
+
break;
|
|
113
|
+
case "node:end":
|
|
114
|
+
writer.write({ type: "data-node-end", data: event.nodeId });
|
|
115
|
+
break;
|
|
116
|
+
case "node:suspense":
|
|
117
|
+
writer.write({ type: "data-node-suspense", data: { nodeId: event.nodeId, data: event.data } });
|
|
118
|
+
break;
|
|
119
|
+
}
|
|
120
|
+
});
|
|
121
|
+
const result = await this.createExecutionContext(runId, initialState, emit, writer);
|
|
72
122
|
context = result.context;
|
|
73
123
|
const firstTime = result.firstTime;
|
|
74
|
-
if (
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
await this.
|
|
124
|
+
if (this.graphMiddleware.length > 0) {
|
|
125
|
+
const graphCtx = {
|
|
126
|
+
runId: context.runId,
|
|
127
|
+
state: () => context.state,
|
|
128
|
+
writer: context.writer,
|
|
129
|
+
isResume: !firstTime
|
|
130
|
+
};
|
|
131
|
+
await composeMiddleware(this.graphMiddleware, async () => {
|
|
132
|
+
await this.runExecutionLoop(context);
|
|
133
|
+
})(graphCtx);
|
|
134
|
+
} else {
|
|
135
|
+
await this.runExecutionLoop(context);
|
|
82
136
|
}
|
|
83
137
|
}
|
|
84
138
|
});
|
|
85
139
|
}
|
|
86
|
-
async
|
|
87
|
-
const
|
|
140
|
+
async execute(runId, initialState, options) {
|
|
141
|
+
const emit = composeEventMiddleware(this.eventMiddleware, options?.onEvent ?? (() => {}));
|
|
142
|
+
const { context, firstTime } = await this.createExecutionContext(runId, initialState, emit, nullWriter);
|
|
143
|
+
if (this.graphMiddleware.length > 0) {
|
|
144
|
+
const graphCtx = {
|
|
145
|
+
runId: context.runId,
|
|
146
|
+
state: () => context.state,
|
|
147
|
+
writer: context.writer,
|
|
148
|
+
isResume: !firstTime
|
|
149
|
+
};
|
|
150
|
+
await composeMiddleware(this.graphMiddleware, async () => {
|
|
151
|
+
await this.runExecutionLoop(context);
|
|
152
|
+
})(graphCtx);
|
|
153
|
+
} else {
|
|
154
|
+
await this.runExecutionLoop(context);
|
|
155
|
+
}
|
|
156
|
+
return context.state;
|
|
157
|
+
}
|
|
158
|
+
async executeInternal(runId, initialState, writer, emit) {
|
|
159
|
+
const { context } = await this.createExecutionContext(runId, initialState, emit, writer);
|
|
88
160
|
await this.runExecutionLoopInternal(context);
|
|
89
161
|
return context.state;
|
|
90
162
|
}
|
|
91
|
-
async createExecutionContext(runId, initialState, writer) {
|
|
92
|
-
const { context, firstTime } = await this.restoreCheckpoint(runId, initialState,
|
|
93
|
-
return { context: { ...context, runId, writer }, firstTime };
|
|
163
|
+
async createExecutionContext(runId, initialState, emit, writer) {
|
|
164
|
+
const { context, firstTime } = await this.restoreCheckpoint(runId, initialState, emit);
|
|
165
|
+
return { context: { ...context, runId, writer, emit }, firstTime };
|
|
94
166
|
}
|
|
95
167
|
async runExecutionLoop(context) {
|
|
96
168
|
await this.executeWithStrategy(context, "return");
|
|
@@ -139,12 +211,12 @@ class CompiledGraph {
|
|
|
139
211
|
hasNodesToExecute(context) {
|
|
140
212
|
return context.currentNodes.length > 0;
|
|
141
213
|
}
|
|
142
|
-
async restoreCheckpoint(runId, initialState,
|
|
214
|
+
async restoreCheckpoint(runId, initialState, emit) {
|
|
143
215
|
const checkpoint = await this.storage.load(runId);
|
|
144
216
|
if (this.isValidCheckpoint(checkpoint)) {
|
|
145
|
-
return { context: this.restoreFromCheckpoint(checkpoint, initialState,
|
|
217
|
+
return { context: this.restoreFromCheckpoint(checkpoint, initialState, emit), firstTime: false };
|
|
146
218
|
}
|
|
147
|
-
return { context: this.createFreshExecution(initialState,
|
|
219
|
+
return { context: this.createFreshExecution(initialState, emit), firstTime: true };
|
|
148
220
|
}
|
|
149
221
|
isValidCheckpoint(checkpoint) {
|
|
150
222
|
return this.hasNodeIds(checkpoint) && this.hasAtLeastOneNode(checkpoint);
|
|
@@ -155,16 +227,16 @@ class CompiledGraph {
|
|
|
155
227
|
hasAtLeastOneNode(checkpoint) {
|
|
156
228
|
return checkpoint.nodeIds.length > 0;
|
|
157
229
|
}
|
|
158
|
-
restoreFromCheckpoint(checkpoint, initialState,
|
|
159
|
-
const state = this.stateManager.resolve(initialState, checkpoint.state,
|
|
230
|
+
restoreFromCheckpoint(checkpoint, initialState, emit) {
|
|
231
|
+
const state = this.stateManager.resolve(initialState, checkpoint.state, emit);
|
|
160
232
|
return {
|
|
161
233
|
state,
|
|
162
234
|
currentNodes: this.resolveNodeIds(checkpoint.nodeIds),
|
|
163
235
|
suspendedNodes: this.resolveNodeIds(checkpoint.suspendedNodes)
|
|
164
236
|
};
|
|
165
237
|
}
|
|
166
|
-
createFreshExecution(initialState,
|
|
167
|
-
const state = this.stateManager.resolve(initialState, undefined,
|
|
238
|
+
createFreshExecution(initialState, emit) {
|
|
239
|
+
const state = this.stateManager.resolve(initialState, undefined, emit);
|
|
168
240
|
const startNode = this.nodeRegistry.get(BUILT_IN_NODES.START);
|
|
169
241
|
return {
|
|
170
242
|
state,
|
|
@@ -188,43 +260,110 @@ class CompiledGraph {
|
|
|
188
260
|
if (subgraphEntry) {
|
|
189
261
|
return this.executeSubgraphNode(node, context, subgraphEntry);
|
|
190
262
|
}
|
|
191
|
-
|
|
263
|
+
const isBuiltIn = node.id === BUILT_IN_NODES.START || node.id === BUILT_IN_NODES.END;
|
|
264
|
+
context.emit({ type: "node:start", nodeId: node.id });
|
|
192
265
|
try {
|
|
193
|
-
|
|
194
|
-
this.
|
|
266
|
+
const { params, pendingUpdates } = this.createNodeExecutionParams(context, node.id);
|
|
267
|
+
if (!isBuiltIn && this.nodeMiddleware.length > 0) {
|
|
268
|
+
const nodeCtx = {
|
|
269
|
+
runId: context.runId,
|
|
270
|
+
nodeId: node.id,
|
|
271
|
+
state: () => context.state,
|
|
272
|
+
writer: context.writer,
|
|
273
|
+
isSubgraph: false
|
|
274
|
+
};
|
|
275
|
+
await composeMiddleware(this.nodeMiddleware, async () => {
|
|
276
|
+
await node.execute(params);
|
|
277
|
+
})(nodeCtx);
|
|
278
|
+
} else {
|
|
279
|
+
await node.execute(params);
|
|
280
|
+
}
|
|
281
|
+
await Promise.all(pendingUpdates);
|
|
282
|
+
context.emit({ type: "node:end", nodeId: node.id });
|
|
195
283
|
return null;
|
|
196
284
|
} catch (error) {
|
|
197
285
|
if (error instanceof SuspenseError) {
|
|
198
|
-
|
|
286
|
+
context.emit({ type: "node:suspense", nodeId: node.id, data: error.data });
|
|
199
287
|
return { node, error };
|
|
200
288
|
}
|
|
201
289
|
throw error;
|
|
202
290
|
}
|
|
203
291
|
}
|
|
204
|
-
createNodeExecutionParams(context) {
|
|
205
|
-
|
|
292
|
+
createNodeExecutionParams(context, nodeId = null) {
|
|
293
|
+
const pendingUpdates = [];
|
|
294
|
+
const params = {
|
|
206
295
|
state: () => context.state,
|
|
207
296
|
writer: context.writer,
|
|
208
297
|
suspense: this.createSuspenseFunction(),
|
|
209
298
|
update: (update) => {
|
|
210
|
-
|
|
299
|
+
const p = (async () => {
|
|
300
|
+
if (this.stateMiddleware.length > 0) {
|
|
301
|
+
const resolvedUpdate = typeof update === "function" ? update(context.state) : update;
|
|
302
|
+
const stateCtx = {
|
|
303
|
+
runId: context.runId,
|
|
304
|
+
nodeId,
|
|
305
|
+
currentState: context.state,
|
|
306
|
+
update,
|
|
307
|
+
resolvedUpdate
|
|
308
|
+
};
|
|
309
|
+
const finalPartial = await composeMiddleware(this.stateMiddleware, async (ctx) => ctx.resolvedUpdate)(stateCtx);
|
|
310
|
+
context.state = { ...context.state, ...finalPartial };
|
|
311
|
+
context.emit({ type: "state", state: context.state });
|
|
312
|
+
} else {
|
|
313
|
+
context.state = this.stateManager.apply(context.state, update);
|
|
314
|
+
context.emit({ type: "state", state: context.state });
|
|
315
|
+
}
|
|
316
|
+
})();
|
|
317
|
+
pendingUpdates.push(p);
|
|
318
|
+
return p;
|
|
211
319
|
}
|
|
212
320
|
};
|
|
321
|
+
return { params, pendingUpdates };
|
|
213
322
|
}
|
|
214
323
|
async executeSubgraphNode(node, context, entry) {
|
|
215
324
|
const { subgraph, options } = entry;
|
|
216
|
-
|
|
325
|
+
context.emit({ type: "node:start", nodeId: node.id });
|
|
217
326
|
const subgraphRunId = this.generateSubgraphRunId(context.runId, node.id);
|
|
218
327
|
try {
|
|
219
|
-
const
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
328
|
+
const executeSubgraph = async () => {
|
|
329
|
+
const childRunner = new CompiledGraph(subgraph.nodes, subgraph.edges, subgraph.subgraphs, { storage: this.storage }, [], this.nodeMiddleware, this.stateMiddleware, this.eventMiddleware);
|
|
330
|
+
const childFinalState = await childRunner.executeInternal(subgraphRunId, options.input(context.state), context.writer, context.emit);
|
|
331
|
+
const parentUpdate = options.output(childFinalState, context.state);
|
|
332
|
+
if (this.stateMiddleware.length > 0) {
|
|
333
|
+
const resolvedUpdate = typeof parentUpdate === "function" ? parentUpdate(context.state) : parentUpdate;
|
|
334
|
+
const stateCtx = {
|
|
335
|
+
runId: context.runId,
|
|
336
|
+
nodeId: node.id,
|
|
337
|
+
currentState: context.state,
|
|
338
|
+
update: parentUpdate,
|
|
339
|
+
resolvedUpdate
|
|
340
|
+
};
|
|
341
|
+
const finalPartial = await composeMiddleware(this.stateMiddleware, async (ctx) => ctx.resolvedUpdate)(stateCtx);
|
|
342
|
+
context.state = { ...context.state, ...finalPartial };
|
|
343
|
+
} else {
|
|
344
|
+
context.state = this.stateManager.apply(context.state, parentUpdate);
|
|
345
|
+
}
|
|
346
|
+
context.emit({ type: "state", state: context.state });
|
|
347
|
+
};
|
|
348
|
+
if (this.nodeMiddleware.length > 0) {
|
|
349
|
+
const nodeCtx = {
|
|
350
|
+
runId: context.runId,
|
|
351
|
+
nodeId: node.id,
|
|
352
|
+
state: () => context.state,
|
|
353
|
+
writer: context.writer,
|
|
354
|
+
isSubgraph: true
|
|
355
|
+
};
|
|
356
|
+
await composeMiddleware(this.nodeMiddleware, async () => {
|
|
357
|
+
await executeSubgraph();
|
|
358
|
+
})(nodeCtx);
|
|
359
|
+
} else {
|
|
360
|
+
await executeSubgraph();
|
|
361
|
+
}
|
|
362
|
+
context.emit({ type: "node:end", nodeId: node.id });
|
|
224
363
|
return null;
|
|
225
364
|
} catch (error) {
|
|
226
365
|
if (error instanceof SuspenseError) {
|
|
227
|
-
|
|
366
|
+
context.emit({ type: "node:suspense", nodeId: node.id, data: error.data });
|
|
228
367
|
return { node, error };
|
|
229
368
|
}
|
|
230
369
|
throw error;
|
|
@@ -262,35 +401,21 @@ class CompiledGraph {
|
|
|
262
401
|
}
|
|
263
402
|
}
|
|
264
403
|
|
|
265
|
-
class NodeEventEmitter {
|
|
266
|
-
emitStart(writer, nodeId) {
|
|
267
|
-
writer.write({ type: "data-node-start", data: nodeId });
|
|
268
|
-
}
|
|
269
|
-
emitEnd(writer, nodeId) {
|
|
270
|
-
writer.write({ type: "data-node-end", data: nodeId });
|
|
271
|
-
}
|
|
272
|
-
emitSuspense(writer, nodeId, data) {
|
|
273
|
-
writer.write({ type: "data-node-suspense", data: { nodeId, data } });
|
|
274
|
-
}
|
|
275
|
-
}
|
|
276
|
-
|
|
277
404
|
class StateManager {
|
|
278
|
-
apply(state, update
|
|
279
|
-
|
|
405
|
+
apply(state, update) {
|
|
406
|
+
return {
|
|
280
407
|
...state,
|
|
281
408
|
...typeof update === "function" ? update(state) : update
|
|
282
409
|
};
|
|
283
|
-
writer.write({ type: "data-state", data: newState });
|
|
284
|
-
return newState;
|
|
285
410
|
}
|
|
286
|
-
resolve(initialState, existingState,
|
|
411
|
+
resolve(initialState, existingState, emit) {
|
|
287
412
|
if (this.isStateFactory(initialState)) {
|
|
288
413
|
const newState2 = initialState(existingState);
|
|
289
|
-
|
|
414
|
+
emit({ type: "state", state: newState2 });
|
|
290
415
|
return newState2;
|
|
291
416
|
}
|
|
292
417
|
const newState = existingState ?? initialState;
|
|
293
|
-
|
|
418
|
+
emit({ type: "state", state: newState });
|
|
294
419
|
return newState;
|
|
295
420
|
}
|
|
296
421
|
isStateFactory(initialState) {
|
|
@@ -303,6 +428,7 @@ class Graph {
|
|
|
303
428
|
nodeRegistry = new Map;
|
|
304
429
|
edgeRegistry = new Map;
|
|
305
430
|
subgraphRegistry = new Map;
|
|
431
|
+
middlewares = [];
|
|
306
432
|
constructor() {
|
|
307
433
|
this.registerBuiltInNodes();
|
|
308
434
|
}
|
|
@@ -334,8 +460,26 @@ class Graph {
|
|
|
334
460
|
get subgraphs() {
|
|
335
461
|
return this.subgraphRegistry;
|
|
336
462
|
}
|
|
463
|
+
use(middleware) {
|
|
464
|
+
this.middlewares.push(middleware);
|
|
465
|
+
return this;
|
|
466
|
+
}
|
|
337
467
|
compile(options = {}) {
|
|
338
|
-
|
|
468
|
+
const graphMiddleware = [];
|
|
469
|
+
const nodeMiddleware = [];
|
|
470
|
+
const stateMiddleware = [];
|
|
471
|
+
const eventMiddleware = [];
|
|
472
|
+
for (const mw of this.middlewares) {
|
|
473
|
+
if (mw.graph)
|
|
474
|
+
graphMiddleware.push(mw.graph);
|
|
475
|
+
if (mw.node)
|
|
476
|
+
nodeMiddleware.push(mw.node);
|
|
477
|
+
if (mw.state)
|
|
478
|
+
stateMiddleware.push(mw.state);
|
|
479
|
+
if (mw.event)
|
|
480
|
+
eventMiddleware.push(mw.event);
|
|
481
|
+
}
|
|
482
|
+
return new CompiledGraph(this.nodeRegistry, this.edgeRegistry, this.subgraphRegistry, options, graphMiddleware, nodeMiddleware, stateMiddleware, eventMiddleware);
|
|
339
483
|
}
|
|
340
484
|
toMermaid(options) {
|
|
341
485
|
const generator = new MermaidGenerator(this.nodeRegistry, this.edgeRegistry, this.subgraphRegistry);
|
|
@@ -468,6 +612,7 @@ export {
|
|
|
468
612
|
isGraphDataPart,
|
|
469
613
|
graph,
|
|
470
614
|
consumeAndMergeStream,
|
|
615
|
+
composeMiddleware,
|
|
471
616
|
SuspenseError,
|
|
472
617
|
RedisStorage,
|
|
473
618
|
InMemoryStorage,
|
package/dist/middleware.d.ts
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export declare function composeMiddleware<Ctx, R = void>(middlewares: Array<(ctx: Ctx, next: () => Promise<R>) => Promise<R>>, action: (ctx: Ctx) => Promise<R>): (ctx: Ctx) => Promise<R>;
|
package/dist/types.d.ts
CHANGED
|
@@ -3,7 +3,7 @@ export declare namespace GraphSDK {
|
|
|
3
3
|
type StateUpdate<State> = Partial<State> | ((state: State) => Partial<State>);
|
|
4
4
|
interface SubgraphOptions<ParentState extends Record<string, unknown>, ChildState extends Record<string, unknown>> {
|
|
5
5
|
input: (parentState: ParentState) => ChildState;
|
|
6
|
-
output: (childState: ChildState, parentState: ParentState) =>
|
|
6
|
+
output: (childState: ChildState, parentState: ParentState) => StateUpdate<ParentState>;
|
|
7
7
|
}
|
|
8
8
|
interface Graph<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
9
9
|
nodes: Map<NodeKeys, Node<State, NodeKeys>>;
|
|
@@ -15,7 +15,7 @@ export declare namespace GraphSDK {
|
|
|
15
15
|
state: () => Readonly<State>;
|
|
16
16
|
writer: UIMessageStreamWriter;
|
|
17
17
|
suspense: (data?: unknown) => never;
|
|
18
|
-
update: (update: StateUpdate<State>) => void
|
|
18
|
+
update: (update: StateUpdate<State>) => Promise<void>;
|
|
19
19
|
}) => Promise<void> | void;
|
|
20
20
|
}
|
|
21
21
|
interface Edge<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
@@ -37,27 +37,58 @@ export declare namespace GraphSDK {
|
|
|
37
37
|
state: State;
|
|
38
38
|
currentNodes: Node<State, NodeKeys>[];
|
|
39
39
|
suspendedNodes: Node<State, NodeKeys>[];
|
|
40
|
-
writer:
|
|
40
|
+
writer: Writer;
|
|
41
|
+
emit: (event: GraphEvent<State, NodeKeys>) => void;
|
|
41
42
|
}
|
|
42
43
|
interface GraphOptions<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
43
44
|
storage?: GraphStorage<State, NodeKeys>;
|
|
44
|
-
onFinish?: (args: {
|
|
45
|
-
state: State;
|
|
46
|
-
}) => Promise<void> | void;
|
|
47
|
-
onStart?: (args: {
|
|
48
|
-
state: State;
|
|
49
|
-
writer: UIMessageStreamWriter;
|
|
50
|
-
}) => Promise<void> | void;
|
|
51
45
|
}
|
|
52
46
|
interface CompileOptions<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
53
47
|
storage?: GraphStorage<State, NodeKeys>;
|
|
54
|
-
onFinish?: (args: {
|
|
55
|
-
state: State;
|
|
56
|
-
}) => Promise<void> | void;
|
|
57
|
-
onStart?: (args: {
|
|
58
|
-
state: State;
|
|
59
|
-
writer: UIMessageStreamWriter;
|
|
60
|
-
}) => Promise<void> | void;
|
|
61
48
|
}
|
|
62
49
|
type Writer = Parameters<Parameters<typeof createUIMessageStream>[0]['execute']>[0]['writer'];
|
|
50
|
+
interface GraphMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
51
|
+
runId: string;
|
|
52
|
+
state: () => Readonly<State>;
|
|
53
|
+
writer: Writer;
|
|
54
|
+
isResume: boolean;
|
|
55
|
+
}
|
|
56
|
+
interface NodeMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
57
|
+
runId: string;
|
|
58
|
+
nodeId: NodeKeys;
|
|
59
|
+
state: () => Readonly<State>;
|
|
60
|
+
writer: Writer;
|
|
61
|
+
isSubgraph: boolean;
|
|
62
|
+
}
|
|
63
|
+
interface StateMiddlewareContext<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
64
|
+
runId: string;
|
|
65
|
+
nodeId: NodeKeys | null;
|
|
66
|
+
currentState: Readonly<State>;
|
|
67
|
+
update: StateUpdate<State>;
|
|
68
|
+
resolvedUpdate: Partial<State>;
|
|
69
|
+
}
|
|
70
|
+
type GraphMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: GraphMiddlewareContext<State, NodeKeys>, next: () => Promise<void>) => Promise<void>;
|
|
71
|
+
type NodeMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: NodeMiddlewareContext<State, NodeKeys>, next: () => Promise<void>) => Promise<void>;
|
|
72
|
+
type StateMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (ctx: StateMiddlewareContext<State, NodeKeys>, next: () => Promise<Partial<State>>) => Promise<Partial<State>>;
|
|
73
|
+
type GraphEvent<State extends Record<string, unknown>, NodeKeys extends string> = {
|
|
74
|
+
type: 'state';
|
|
75
|
+
state: State;
|
|
76
|
+
} | {
|
|
77
|
+
type: 'node:start';
|
|
78
|
+
nodeId: NodeKeys;
|
|
79
|
+
} | {
|
|
80
|
+
type: 'node:end';
|
|
81
|
+
nodeId: NodeKeys;
|
|
82
|
+
} | {
|
|
83
|
+
type: 'node:suspense';
|
|
84
|
+
nodeId: NodeKeys;
|
|
85
|
+
data: unknown;
|
|
86
|
+
};
|
|
87
|
+
type EventMiddleware<State extends Record<string, unknown>, NodeKeys extends string> = (event: GraphEvent<State, NodeKeys>, next: () => void) => void;
|
|
88
|
+
interface Middleware<State extends Record<string, unknown>, NodeKeys extends string> {
|
|
89
|
+
graph?: GraphMiddleware<State, NodeKeys>;
|
|
90
|
+
node?: NodeMiddleware<State, NodeKeys>;
|
|
91
|
+
state?: StateMiddleware<State, NodeKeys>;
|
|
92
|
+
event?: EventMiddleware<State, NodeKeys>;
|
|
93
|
+
}
|
|
63
94
|
}
|