graphai 0.0.4 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. package/.eslintrc.js +27 -46
  2. package/.github/workflows/node.js.yml +1 -0
  3. package/README.md +42 -10
  4. package/lib/graphai.d.ts +57 -27
  5. package/lib/graphai.js +157 -39
  6. package/package.json +11 -5
  7. package/samples/agents/arxiv_agent.ts +46 -0
  8. package/samples/agents/slashgpt_agent.ts +21 -0
  9. package/samples/express.ts +47 -0
  10. package/samples/graphs/arxiv.yml +29 -0
  11. package/{tests → samples}/sample_gpt.ts +11 -11
  12. package/samples/sample_paper_ai.ts +26 -0
  13. package/src/graphai.ts +213 -77
  14. package/tests/agents/agents.ts +24 -0
  15. package/tests/graphai/test_dispatch.ts +42 -0
  16. package/tests/graphai/test_http_client.ts +40 -0
  17. package/tests/{test_multiple_functions.ts → graphai/test_multiple_functions.ts} +13 -14
  18. package/tests/graphai/test_sample_flow.ts +72 -0
  19. package/tests/graphs/test_dispatch.yml +24 -0
  20. package/tests/graphs/test_error.yml +22 -0
  21. package/tests/graphs/test_multiple_functions_1.yml +8 -2
  22. package/tests/graphs/test_source.yml +18 -0
  23. package/tests/graphs/test_source2.yml +17 -0
  24. package/tests/graphs/test_timeout.yml +22 -0
  25. package/tests/http-server/README.md +10 -0
  26. package/tests/http-server/docs/llm.json +4 -0
  27. package/tests/http-server/docs/llm2.json +4 -0
  28. package/tests/{file_utils.ts → utils/file_utils.ts} +10 -1
  29. package/tests/utils/runner.ts +40 -0
  30. package/tsconfig.json +5 -2
  31. package/tests/test_sample_flow.ts +0 -63
  32. /package/{tests/graphs/sample3.yml → samples/graphs/slash_gpt.yml} +0 -0
  33. /package/tests/graphs/{sample1.yml → test_base.yml} +0 -0
  34. /package/tests/graphs/{sample2.yml → test_retry.yml} +0 -0
  35. /package/tests/{utils.ts → utils/utils.ts} +0 -0
@@ -0,0 +1,46 @@
1
+ import search from "arXiv-api-ts";
2
+
3
+ import { AgentFunction } from "@/graphai";
4
+
5
/** Subset of an arXiv search entry that the agents below pass along. */
type arxivData = { id: string; title: string; summary: string };

/**
 * Searches arXiv for papers matching all of the given keywords, newest update first.
 *
 * @param keywords - Terms placed together in a single `include` query group.
 * @param limit - Maximum number of entries to request (defaults to 10).
 * @returns The `entries` array of the search response, or `[]` when it is absent.
 */
const search_arxiv_papers = async (keywords: string[], limit = 10) => {
  const includes = keywords.map((name) => ({ name }));
  const papers = await search({
    searchQueryParams: [{ include: includes }],
    sortBy: "lastUpdatedDate",
    sortOrder: "descending",
    start: 0,
    maxResults: limit,
  });
  return papers.entries || [];
};

/**
 * Agent that searches arXiv with `params.keywords` and returns `{ id, title, summary }`
 * for each hit. `params.limit` is optional; search_arxiv_papers falls back to 10.
 */
export const arxivAgent: AgentFunction<{ keywords: string[]; limit?: number }, arxivData[]> = async (context) => {
  const { keywords, limit } = context.params;
  const arxivResult = await search_arxiv_papers(keywords, limit);
  // The search library's entry type is not visible here; keep only the fields we forward.
  return arxivResult.map((entry: any) => {
    const { id, title, summary } = entry;
    return { id, title, summary };
  });
};

/**
 * Agent that renders the papers produced by arxivAgent (expected under `payload.inputData`,
 * see payloadMapping in the graph) as plain text. Each paper becomes labelled
 * id/title/summary lines, and papers are separated by two blank lines.
 * Returns "" when there is no input.
 */
export const arxiv2TextAgent: AgentFunction<Record<string, never>, string, arxivData[]> = async (context) => {
  const papers = context.payload.inputData ?? [];
  return papers
    .map(({ id, title, summary }) => ["id:", id, "title:", title, "summary:", summary].join("\n"))
    .join("\n\n\n");
};
@@ -0,0 +1,21 @@
1
+ import path from "path";
2
+ import { AgentFunction } from "@/graphai";
3
+ import { ChatSession, ChatConfig, ManifestData } from "slashgpt";
4
+
5
const config = new ChatConfig(path.resolve(__dirname));

/**
 * Agent that asks slashgpt one question and returns the last message of the session
 * history as `{ answer }`.
 *
 * The question is `params.prompt` followed by `payload.inputData` (in arxiv.yml this is
 * the text produced by arxiv2TextAgent), separated by a blank line. Missing or empty
 * parts are skipped, so no stray leading blank lines are sent.
 *
 * @throws Error when the session history has no message after the call loop.
 */
export const slashGPTAgent: AgentFunction<{ manifest: ManifestData; prompt: string }, { answer: string }, string> = async (
  context,
) => {
  // Log identifying info only; the payload can be large (e.g. a batch of paper summaries).
  console.log("executing", context.nodeId, context.params);
  const session = new ChatSession(config, context.params?.manifest ?? {});

  const prompt = [context.params?.prompt, context.payload.inputData]
    .filter((part): part is string => part !== undefined && part !== "")
    .join("\n\n");
  session.append_user_question(prompt);

  await session.call_loop(() => {});
  const message = session.history.last_message();
  if (message === undefined) {
    throw new Error("No message in the history");
  }
  return { answer: message.content };
};
@@ -0,0 +1,47 @@
1
// npx ts-node samples/express.ts
import { GraphAI, AgentFunction } from "@/graphai";

import express from "express";

const app = express();

/**
 * GET /mock: runs a one-node graph whose agent only logs "hello", and responds with
 * `{ result: <graph results> }`.
 *
 * GraphAI.run() rejects when a node fails. Express 4 does not handle rejected promises
 * from async handlers, so the failure is reported here as HTTP 500. Otherwise the request
 * would hang and the rejection would go unhandled.
 */
const graphAISample = async (req: express.Request, res: express.Response) => {
  const graph_data = {
    nodes: {
      node1: {
        params: {},
      },
    },
    concurrency: 8,
  };
  const testFunction: AgentFunction<Record<string, string>> = async () => {
    console.log("hello");
    return {};
  };
  try {
    const graph = new GraphAI(graph_data, testFunction);
    const response = await graph.run();
    res.json({ result: response });
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error);
    res.status(500).json({ error: message });
  }
};

/** GET /: echoes the route params and the query string back as JSON. */
const hello = (req: express.Request, res: express.Response) => {
  const { params, query } = req;
  // res.json() already ends the response; no explicit res.end() needed.
  res.json({
    result: [
      {
        message: "hello",
        params,
        query,
      },
    ],
  });
};

app.use(express.json());
app.get("/", hello);
app.get("/mock", graphAISample);

// Port can be overridden with the PORT environment variable; defaults to 8080 as before.
const port = Number(process.env.PORT ?? 8080);
app.listen(port, () => {
  console.log("Running Server");
});
@@ -0,0 +1,29 @@
1
# Each node selects its agent with `agentId` (GraphAI falls back to the "default"
# agent when the id is missing or not registered, so a misspelled key fails silently).
nodes:
  searchArxiv:
    # Latest papers matching all keywords.
    params:
      keywords:
        - llm
        - gpt
      limit: 10
    agentId: arxivAgent
  arxiv2TextAgent:
    # Renders the paper list as text; the agent reads it from payload.inputData.
    inputs: [searchArxiv]
    agentId: arxiv2TextAgent
    payloadMapping:
      searchArxiv: inputData
  slashGPTAgent:
    # Summarizes the papers; the prompt is followed by the text from arxiv2TextAgent.
    inputs: [arxiv2TextAgent]
    payloadMapping:
      arxiv2TextAgent: inputData
    agentId: slashGPTAgent
    params:
      prompt: |
        与えられたそれぞれの論文の要点をまとめ、以下の項目で日本語で出力せよ。それぞれの項目は最大でも180文字以内に要約せよ。
        ```
        論文名:タイトルの日本語訳
        キーワード:この論文のキーワード
        課題:この論文が解決する課題
        手法:この論文が提案する手法
        結果:提案手法によって得られた結果
        ```
@@ -1,12 +1,12 @@
1
1
  import path from "path";
2
- import { GraphAI, NodeExecuteContext } from "../src/graphai";
3
- import { ChatSession, ChatConfig } from "slashgpt";
4
- import { readManifestData } from "./file_utils";
2
+ import { GraphAI, AgentFunction } from "@/graphai";
3
+ import { ChatSession, ChatConfig, ManifestData } from "slashgpt";
4
+ import { readGraphaiData } from "~/utils/file_utils";
5
5
 
6
6
  const config = new ChatConfig(path.resolve(__dirname));
7
7
 
8
- const testFunction = async (context: NodeExecuteContext<Record<string, string>>) => {
9
- console.log("executing", context.nodeId, context.params, context.payload);
8
+ const slashGPTAgent: AgentFunction<{ manifest: ManifestData; prompt: string }, { answer: string }> = async (context) => {
9
+ console.log("executing", context.nodeId, context.params);
10
10
  const session = new ChatSession(config, context.params.manifest ?? {});
11
11
  const prompt = Object.keys(context.payload).reduce((prompt, key) => {
12
12
  return prompt.replace("${" + key + "}", context.payload[key]!["answer"]);
@@ -19,19 +19,19 @@ const testFunction = async (context: NodeExecuteContext<Record<string, string>>)
19
19
  throw new Error("No message in the history");
20
20
  }
21
21
  const result = { answer: message.content };
22
- console.log(result);
23
22
  return result;
24
23
  };
25
24
 
26
/**
 * Loads the graph definition at `file` (appended to this script's directory), runs it
 * with slashGPTAgent as the default agent, and prints the results.
 */
const runAgent = async (file: string) => {
  const file_path = path.resolve(__dirname) + file;
  const graph_data = readGraphaiData(file_path);
  const graph = new GraphAI(graph_data, slashGPTAgent);
  const result = await graph.run();
  console.log(result);
};

const main = async () => {
  await runAgent("/graphs/slash_gpt.yml");
  console.log("COMPLETE 1");
};

// GraphAI.run() rejects with the first node error; report it instead of leaving an
// unhandled promise rejection.
main().catch((error) => {
  console.error(error);
  process.exitCode = 1;
});
@@ -0,0 +1,26 @@
1
+ import path from "path";
2
+ import search from "arXiv-api-ts";
3
+
4
+ import { GraphAI, AgentFunction } from "@/graphai";
5
+ import { readGraphaiData } from "~/utils/file_utils";
6
+
7
+ import { slashGPTAgent } from "./agents/slashgpt_agent";
8
+ import { arxivAgent, arxiv2TextAgent } from "./agents/arxiv_agent";
9
+
10
/** Fallback ("default") agent for nodes without a registered agentId; ignores its input and returns `{}`. */
export const parrotingAgent: AgentFunction = async () => {
  return {};
};

/**
 * Loads the graph definition at `file` (appended to this script's directory), runs it
 * with the arXiv/slashgpt agents registered, and prints the results.
 */
const runAgent = async (file: string) => {
  const file_path = path.resolve(__dirname) + file;
  const graph_data = readGraphaiData(file_path);
  const graph = new GraphAI(graph_data, { default: parrotingAgent, arxivAgent, arxiv2TextAgent, slashGPTAgent });
  const result = await graph.run();
  console.log(result);
};

const main = async () => {
  await runAgent("/graphs/arxiv.yml");
  console.log("COMPLETE 1");
};

// GraphAI.run() rejects with the first node error; report it instead of leaving an
// unhandled promise rejection.
main().catch((error) => {
  console.error(error);
  process.exitCode = 1;
});
package/src/graphai.ts CHANGED
@@ -1,68 +1,90 @@
1
/**
 * Lifecycle state of a node. The numeric values are what TransactionLog.state records.
 */
export enum NodeState {
  Waiting = 0,
  Executing = 1,
  Failed = 2,
  TimedOut = 3,
  Completed = 4,
  Injected = 5, // result supplied from outside (source nodes)
  Dispatched = 6, // outputs routed to other nodes instead of being stored
}

/** A result produced by an agent; `undefined` when there is none. */
type ResultData<ResultType = Record<string, any>> = ResultType | undefined;

/** Results keyed by node id, or by the alias configured in `payloadMapping`. */
type ResultDataDictonary<ResultType = Record<string, any>> = { [key: string]: ResultData<ResultType> };

/** Agent-specific parameters; handed to the agent unchanged. */
export type NodeDataParams<ParamsType = Record<string, any>> = ParamsType;

/** Definition of a single node in the graph data. */
type NodeData = {
  params: NodeDataParams;
  inputs?: string[]; // ids of nodes whose results this node needs
  payloadMapping?: { [inputNodeId: string]: string }; // renames inputs inside the payload
  agentId?: string; // key into the agent dictionary; "default" when omitted
  retry?: number;
  timeout?: number; // msec
  source?: boolean; // source nodes are never executed; they only receive injected results
  dispatch?: { [outputId: string]: string }; // route to node
};

/** Top-level graph definition. */
export type GraphData = {
  nodes: { [nodeId: string]: NodeData };
  concurrency?: number;
};

/** One transaction-log entry, written for each execution attempt or injection. */
export type TransactionLog = {
  nodeId: string;
  state: NodeState;
  startTime: number;
  endTime?: number;
  retryCount: number;
  agentId?: string;
  params?: NodeDataParams;
  payload?: ResultDataDictonary<ResultData>;
  errorMessage?: string;
  result?: ResultData;
};

/** Argument passed to an agent function. */
export type AgentFunctionContext<ParamsType, ResultType, PreviousResultType> = {
  nodeId: string;
  retry: number; // number of retries so far (0 on the first attempt)
  params: NodeDataParams<ParamsType>;
  payload: ResultDataDictonary<PreviousResultType>; // results of the input nodes
};

/** An agent: receives the node context and resolves to its result. */
export type AgentFunction<
  ParamsType = Record<string, any>,
  ResultType = Record<string, any>,
  PreviousResultType = Record<string, any>
> = (context: AgentFunctionContext<ParamsType, ResultType, PreviousResultType>) => Promise<ResultData<ResultType>>;

/** Agents keyed by agent id; GraphAI requires a "default" entry. */
export type AgentFunctionDictonary = Record<string, AgentFunction<any, any, any>>;
36
56
 
37
- class Node<ResultType = Record<string, any>> {
57
+ class Node {
38
58
  public nodeId: string;
39
- public params: NodeDataParams; // App-specific parameters
59
+ public params: NodeDataParams; // Agent-specific parameters
40
60
  public inputs: Array<string>; // List of nodes this node needs data from.
61
+ public payloadMapping: Record<string, string>;
41
62
  public pendings: Set<string>; // List of nodes this node is waiting data from.
42
- public waitlist: Set<string>; // List of nodes which need data from this node.
43
- public state: NodeState;
44
- public functionName: string;
45
- public result: ResultData<ResultType>;
63
+ public waitlist = new Set<string>(); // List of nodes which need data from this node.
64
+ public state = NodeState.Waiting;
65
+ public agentId: string;
66
+ public result: ResultData = undefined;
46
67
  public retryLimit: number;
47
- public retryCount: number;
68
+ public retryCount: number = 0;
48
69
  public transactionId: undefined | number; // To reject callbacks from timed-out transactions
49
- public timeout: number; // msec
70
+ public timeout?: number; // msec
71
+ public error?: Error;
72
+ public source: boolean;
73
+ public dispatch?: Record<string, string>; // outputId to nodeId mapping
50
74
 
51
- private graph: GraphAI<ResultType>;
75
+ private graph: GraphAI;
52
76
 
53
- constructor(nodeId: string, data: NodeData, graph: GraphAI<ResultType>) {
77
+ constructor(nodeId: string, data: NodeData, graph: GraphAI) {
54
78
  this.nodeId = nodeId;
55
79
  this.inputs = data.inputs ?? [];
80
+ this.payloadMapping = data.payloadMapping ?? {};
56
81
  this.pendings = new Set(this.inputs);
57
82
  this.params = data.params;
58
- this.waitlist = new Set<string>();
59
- this.state = NodeState.Waiting;
60
- this.functionName = data.functionName ?? "default";
61
- this.result = undefined;
83
+ this.agentId = data.agentId ?? "default";
62
84
  this.retryLimit = data.retry ?? 0;
63
- this.retryCount = 0;
64
- this.timeout = data.timeout ?? 0;
65
-
85
+ this.timeout = data.timeout;
86
+ this.source = data.source === true;
87
+ this.dispatch = data.dispatch;
66
88
  this.graph = graph;
67
89
  }
68
90
 
@@ -70,101 +92,171 @@ class Node<ResultType = Record<string, any>> {
70
92
  return `${this.nodeId}: ${this.state} ${[...this.waitlist]}`;
71
93
  }
72
94
 
73
/**
 * Handles a failed or timed-out attempt: re-executes while retries remain, otherwise
 * records the final state and error and frees this node's running slot.
 */
private retry(state: NodeState, error: Error) {
  const canRetry = this.retryCount < this.retryLimit;
  if (!canRetry) {
    this.state = state;
    this.result = undefined;
    this.error = error;
    // Clearing the transaction turns any late callback of a timed-out attempt into a mismatch.
    this.transactionId = undefined;
    this.graph.removeRunning(this);
    return;
  }
  this.retryCount++;
  void this.execute();
}
83
107
 
84
108
  public removePending(nodeId: string) {
85
109
  this.pendings.delete(nodeId);
86
- this.pushQueueIfReady();
110
+ if (this.graph.isRunning) {
111
+ this.pushQueueIfReady();
112
+ }
87
113
  }
88
114
 
89
115
  public payload() {
90
- return this.inputs.reduce((results: ResultDataDictonary<ResultType>, nodeId) => {
91
- results[nodeId] = this.graph.nodes[nodeId].result;
116
+ return this.inputs.reduce((results: ResultDataDictonary, nodeId) => {
117
+ if (this.payloadMapping && this.payloadMapping[nodeId]) {
118
+ results[this.payloadMapping[nodeId]] = this.graph.nodes[nodeId].result;
119
+ } else {
120
+ results[nodeId] = this.graph.nodes[nodeId].result;
121
+ }
92
122
  return results;
93
123
  }, {});
94
124
  }
95
125
 
96
126
  public pushQueueIfReady() {
97
- if (this.pendings.size === 0) {
127
+ if (this.pendings.size === 0 && !this.source) {
98
128
  this.graph.pushQueue(this);
99
129
  }
100
130
  }
101
131
 
132
/**
 * Supplies a result to a source node from outside (GraphAI.injectResult or a dispatching
 * node). Appends an Injected log entry whose start and end times are equal, then publishes
 * the result. Calls on non-source nodes are rejected with an error log.
 */
public injectResult(result: ResultData) {
  if (!this.source) {
    console.error("- injectResult called on non-source node.", this.nodeId);
    return;
  }
  const now = Date.now();
  const entry: TransactionLog = {
    nodeId: this.nodeId,
    retryCount: this.retryCount,
    state: NodeState.Injected,
    startTime: now,
    endTime: now,
  };
  this.graph.appendLog(entry);
  this.setResult(result, NodeState.Injected);
}

/** Stores the result and state, then tells every waiting node that this input is available. */
private setResult(result: ResultData, state: NodeState) {
  this.state = state;
  this.result = result;
  for (const waitingId of this.waitlist) {
    // Todo: Avoid running before Run() (removePending only queues while the graph runs).
    this.graph.nodes[waitingId].removePending(this.nodeId);
  }
}
157
+
102
158
  public async execute() {
159
+ const payload = this.payload();
160
+ const log: TransactionLog = {
161
+ nodeId: this.nodeId,
162
+ retryCount: this.retryCount,
163
+ state: NodeState.Executing,
164
+ startTime: Date.now(),
165
+ agentId: this.agentId,
166
+ params: this.params,
167
+ payload,
168
+ };
169
+ this.graph.appendLog(log);
103
170
  this.state = NodeState.Executing;
104
- const transactionId = Date.now();
171
+ const transactionId = log.startTime;
105
172
  this.transactionId = transactionId;
106
173
 
107
- if (this.timeout > 0) {
174
+ if (this.timeout && this.timeout > 0) {
108
175
  setTimeout(() => {
109
176
  if (this.state === NodeState.Executing && this.transactionId === transactionId) {
110
- console.log("*** timeout", this.timeout);
111
- this.retry(NodeState.TimedOut, undefined);
177
+ console.log(`-- ${this.nodeId}: timeout ${this.timeout}`);
178
+ log.errorMessage = "Timeout";
179
+ log.state = NodeState.TimedOut;
180
+ log.endTime = Date.now();
181
+ this.retry(NodeState.TimedOut, Error("Timeout"));
112
182
  }
113
183
  }, this.timeout);
114
184
  }
115
185
 
116
186
  try {
117
- const callback = this.graph.getCallback(this.functionName);
187
+ const callback = this.graph.getCallback(this.agentId);
118
188
  const result = await callback({
119
189
  nodeId: this.nodeId,
120
190
  retry: this.retryCount,
121
191
  params: this.params,
122
- payload: this.payload(),
192
+ payload,
123
193
  });
124
194
  if (this.transactionId !== transactionId) {
125
- console.log("****** transactionId mismatch (success)");
195
+ console.log(`-- ${this.nodeId}: transactionId mismatch`);
126
196
  return;
127
197
  }
128
- this.state = NodeState.Completed;
129
- this.result = result;
130
- this.waitlist.forEach((nodeId) => {
131
- const node = this.graph.nodes[nodeId];
132
- node.removePending(this.nodeId);
133
- });
198
+
199
+ log.endTime = Date.now();
200
+ log.result = result;
201
+
202
+ const dispatch = this.dispatch;
203
+ if (dispatch !== undefined) {
204
+ Object.keys(result).forEach((outputId) => {
205
+ const nodeId = dispatch[outputId];
206
+ this.graph.injectResult(nodeId, result[outputId]);
207
+ });
208
+ log.state = NodeState.Dispatched;
209
+ this.state = NodeState.Dispatched;
210
+ this.graph.removeRunning(this);
211
+ return;
212
+ }
213
+ log.state = NodeState.Completed;
214
+ this.setResult(result, NodeState.Completed);
134
215
  this.graph.removeRunning(this);
135
- } catch (e) {
216
+ } catch (error) {
136
217
  if (this.transactionId !== transactionId) {
137
- console.log("****** transactionId mismatch (failed)");
218
+ console.log(`-- ${this.nodeId}: transactionId mismatch(error)`);
138
219
  return;
139
220
  }
140
- this.retry(NodeState.Failed, undefined);
221
+ log.state = NodeState.Failed;
222
+ log.endTime = Date.now();
223
+ if (error instanceof Error) {
224
+ log.errorMessage = error.message;
225
+ this.retry(NodeState.Failed, error);
226
+ } else {
227
+ console.error(`-- ${this.nodeId}: Unexpecrted error was caught`);
228
+ log.errorMessage = "Unknown";
229
+ this.retry(NodeState.Failed, Error("Unknown"));
230
+ }
141
231
  }
142
232
  }
143
233
  }
144
234
 
145
/** All nodes of a graph, keyed by node id. */
type GraphNodes = { [nodeId: string]: Node };

/** Maximum number of concurrently running nodes when GraphData.concurrency is omitted. */
const defaultConcurrency = 8;
148
238
 
149
- export class GraphAI<ResultType = Record<string, any>> {
150
- public nodes: GraphNodes<ResultType>;
151
- public callbackDictonary: NodeExecuteDictonary<ResultType>;
152
- private runningNodes: Set<string>;
153
- private nodeQueue: Array<Node<ResultType>>;
239
+ export class GraphAI {
240
+ public nodes: GraphNodes;
241
+ public callbackDictonary: AgentFunctionDictonary;
242
+ public isRunning = false;
243
+ private runningNodes = new Set<string>();
244
+ private nodeQueue: Array<Node> = [];
154
245
  private onComplete: () => void;
155
246
  private concurrency: number;
247
+ private logs: Array<TransactionLog> = [];
156
248
 
157
- constructor(data: GraphData, callbackDictonary: NodeExecuteDictonary<ResultType> | NodeExecute<ResultType>) {
249
+ constructor(data: GraphData, callbackDictonary: AgentFunctionDictonary | AgentFunction<any, any, any>) {
158
250
  this.callbackDictonary = typeof callbackDictonary === "function" ? { default: callbackDictonary } : callbackDictonary;
159
251
  if (this.callbackDictonary["default"] === undefined) {
160
252
  throw new Error("No default function");
161
253
  }
162
- this.concurrency = data.concurrency ?? 2;
163
- this.runningNodes = new Set<string>();
164
- this.nodeQueue = [];
165
- this.onComplete = () => {};
166
- this.nodes = Object.keys(data.nodes).reduce((nodes: GraphNodes<ResultType>, nodeId: string) => {
167
- nodes[nodeId] = new Node<ResultType>(nodeId, data.nodes[nodeId], this);
254
+ this.concurrency = data.concurrency ?? defaultConcurrency;
255
+ this.onComplete = () => {
256
+ console.error("-- SOMETHING IS WRONG: onComplete is called without run()");
257
+ };
258
+ this.nodes = Object.keys(data.nodes).reduce((nodes: GraphNodes, nodeId: string) => {
259
+ nodes[nodeId] = new Node(nodeId, data.nodes[nodeId], this);
168
260
  return nodes;
169
261
  }, {});
170
262
 
@@ -178,9 +270,9 @@ export class GraphAI<ResultType = Record<string, any>> {
178
270
  });
179
271
  }
180
272
 
181
/** Returns the agent registered under `agentId`, or the "default" agent when there is none. */
public getCallback(agentId: string) {
  const registered = agentId ? this.callbackDictonary[agentId] : undefined;
  return registered ? registered : this.callbackDictonary["default"];
}
@@ -193,7 +285,31 @@ export class GraphAI<ResultType = Record<string, any>> {
193
285
  .join("\n");
194
286
  }
195
287
 
288
/** Collects every node result that is not undefined, keyed by node id. */
public results() {
  const collected: ResultDataDictonary = {};
  for (const nodeId of Object.keys(this.nodes)) {
    const { result } = this.nodes[nodeId];
    if (result !== undefined) {
      collected[nodeId] = result;
    }
  }
  return collected;
}

/** Collects the final errors of nodes that gave up after their last retry, keyed by node id. */
public errors() {
  const collected: Record<string, Error> = {};
  for (const nodeId of Object.keys(this.nodes)) {
    const { error } = this.nodes[nodeId];
    if (error !== undefined) {
      collected[nodeId] = error;
    }
  }
  return collected;
}
307
+
196
308
  public async run() {
309
+ if (this.isRunning) {
310
+ console.error("-- Already Running");
311
+ }
312
+ this.isRunning = true;
197
313
  // Nodes without pending data should run immediately.
198
314
  Object.keys(this.nodes).forEach((nodeId) => {
199
315
  const node = this.nodes[nodeId];
@@ -202,21 +318,24 @@ export class GraphAI<ResultType = Record<string, any>> {
202
318
 
203
319
  return new Promise((resolve, reject) => {
204
320
  this.onComplete = () => {
205
- const results = Object.keys(this.nodes).reduce((results: ResultDataDictonary<ResultType>, nodeId) => {
206
- results[nodeId] = this.nodes[nodeId].result;
207
- return results;
208
- }, {});
209
- resolve(results);
321
+ this.isRunning = false;
322
+ const errors = this.errors();
323
+ const nodeIds = Object.keys(errors);
324
+ if (nodeIds.length > 0) {
325
+ reject(errors[nodeIds[0]]);
326
+ } else {
327
+ resolve(this.results());
328
+ }
210
329
  };
211
330
  });
212
331
  }
213
332
 
214
/** Registers the node as running (counted against the concurrency limit) and starts it. */
private runNode(node: Node) {
  const { nodeId } = node;
  this.runningNodes.add(nodeId);
  node.execute();
}
218
337
 
219
- public pushQueue(node: Node<ResultType>) {
338
+ public pushQueue(node: Node) {
220
339
  if (this.runningNodes.size < this.concurrency) {
221
340
  this.runNode(node);
222
341
  } else {
@@ -224,7 +343,7 @@ export class GraphAI<ResultType = Record<string, any>> {
224
343
  }
225
344
  }
226
345
 
227
- public removeRunning(node: Node<ResultType>) {
346
+ public removeRunning(node: Node) {
228
347
  this.runningNodes.delete(node.nodeId);
229
348
  if (this.nodeQueue.length > 0) {
230
349
  const n = this.nodeQueue.shift();
@@ -236,4 +355,21 @@ export class GraphAI<ResultType = Record<string, any>> {
236
355
  this.onComplete();
237
356
  }
238
357
  }
358
+
359
/** Records one transaction-log entry (nodes call this per execution attempt or injection). */
public appendLog(log: TransactionLog) {
  this.logs.push(log);
}

/** Returns the entries recorded so far (the internal array itself, not a copy). */
public transactionLogs() {
  return this.logs;
}

/** Injects `result` into the source node `nodeId`; logs an error when no such node exists. */
public injectResult(nodeId: string, result: ResultData) {
  const target = this.nodes[nodeId];
  if (!target) {
    console.error("-- Invalid nodeId", nodeId);
    return;
  }
  target.injectResult(result);
}
239
375
  }
@@ -0,0 +1,24 @@
1
+ import { AgentFunction } from "@/graphai";
2
+ import { sleep } from "~/utils/utils";
3
+
4
/**
 * Test agent.
 *
 * Waits `params.delay / (retry + 1)` ms. It then throws "Intentional Failure" when
 * `params.fail` is set and fewer than two retries have happened. Otherwise it returns
 * `{ [nodeId]: "output" }` merged with every input result in the payload, in payload
 * order; later inputs win on key clashes.
 */
export const testAgent: AgentFunction<{ delay: number; fail: boolean }> = async (context) => {
  const { nodeId, retry, params, payload } = context;
  console.log("executing", nodeId);
  // Each retry waits proportionally less than the previous attempt.
  await sleep(params.delay / (retry + 1));

  if (params.fail && retry < 2) {
    console.log("failed (intentional)", nodeId, retry);
    throw new Error("Intentional Failure");
  }

  // Object.assign skips undefined inputs, just like spreading them.
  const result = Object.assign({ [nodeId]: "output" }, ...Object.keys(payload).map((key) => payload[key]));
  console.log("completing", nodeId);
  return result;
};