@langgraph-js/pure-graph 1.0.2 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. package/.prettierrc +11 -0
  2. package/README.md +104 -10
  3. package/bun.lock +209 -0
  4. package/dist/adapter/hono/assistants.js +3 -9
  5. package/dist/adapter/hono/endpoint.js +1 -2
  6. package/dist/adapter/hono/runs.js +6 -39
  7. package/dist/adapter/hono/threads.js +5 -46
  8. package/dist/adapter/nextjs/endpoint.d.ts +1 -0
  9. package/dist/adapter/nextjs/endpoint.js +2 -0
  10. package/dist/adapter/nextjs/index.d.ts +1 -0
  11. package/dist/adapter/nextjs/index.js +2 -0
  12. package/dist/adapter/nextjs/router.d.ts +5 -0
  13. package/dist/adapter/nextjs/router.js +168 -0
  14. package/dist/adapter/{hono → nextjs}/zod.d.ts +5 -5
  15. package/dist/adapter/{hono → nextjs}/zod.js +22 -5
  16. package/dist/adapter/zod.d.ts +577 -0
  17. package/dist/adapter/zod.js +119 -0
  18. package/dist/createEndpoint.d.ts +1 -2
  19. package/dist/createEndpoint.js +4 -3
  20. package/dist/global.d.ts +6 -4
  21. package/dist/global.js +10 -5
  22. package/dist/graph/stream.d.ts +1 -1
  23. package/dist/graph/stream.js +18 -10
  24. package/dist/index.d.ts +1 -0
  25. package/dist/index.js +1 -0
  26. package/dist/queue/stream_queue.d.ts +5 -3
  27. package/dist/queue/stream_queue.js +4 -2
  28. package/dist/storage/index.d.ts +9 -4
  29. package/dist/storage/index.js +38 -3
  30. package/dist/storage/redis/queue.d.ts +39 -0
  31. package/dist/storage/redis/queue.js +130 -0
  32. package/dist/storage/sqlite/DB.d.ts +3 -0
  33. package/dist/storage/sqlite/DB.js +14 -0
  34. package/dist/storage/sqlite/checkpoint.d.ts +18 -0
  35. package/dist/storage/sqlite/checkpoint.js +374 -0
  36. package/dist/storage/sqlite/threads.d.ts +43 -0
  37. package/dist/storage/sqlite/threads.js +266 -0
  38. package/dist/storage/sqlite/type.d.ts +15 -0
  39. package/dist/storage/sqlite/type.js +1 -0
  40. package/dist/utils/createEntrypointGraph.d.ts +14 -0
  41. package/dist/utils/createEntrypointGraph.js +11 -0
  42. package/dist/utils/getGraph.js +3 -3
  43. package/examples/nextjs/README.md +36 -0
  44. package/examples/nextjs/app/api/langgraph/[...path]/route.ts +10 -0
  45. package/examples/nextjs/app/favicon.ico +0 -0
  46. package/examples/nextjs/app/globals.css +26 -0
  47. package/examples/nextjs/app/layout.tsx +34 -0
  48. package/examples/nextjs/app/page.tsx +211 -0
  49. package/examples/nextjs/next.config.ts +26 -0
  50. package/examples/nextjs/package.json +24 -0
  51. package/examples/nextjs/postcss.config.mjs +5 -0
  52. package/examples/nextjs/tsconfig.json +27 -0
  53. package/package.json +9 -4
  54. package/packages/agent-graph/demo.json +35 -0
  55. package/packages/agent-graph/package.json +18 -0
  56. package/packages/agent-graph/src/index.ts +47 -0
  57. package/packages/agent-graph/src/tools/tavily.ts +9 -0
  58. package/packages/agent-graph/src/tools.ts +38 -0
  59. package/packages/agent-graph/src/types.ts +42 -0
  60. package/pnpm-workspace.yaml +4 -0
  61. package/src/adapter/hono/assistants.ts +16 -33
  62. package/src/adapter/hono/endpoint.ts +1 -2
  63. package/src/adapter/hono/runs.ts +15 -51
  64. package/src/adapter/hono/threads.ts +15 -70
  65. package/src/adapter/nextjs/endpoint.ts +2 -0
  66. package/src/adapter/nextjs/index.ts +2 -0
  67. package/src/adapter/nextjs/router.ts +193 -0
  68. package/src/adapter/{hono → nextjs}/zod.ts +22 -5
  69. package/src/adapter/zod.ts +135 -0
  70. package/src/createEndpoint.ts +12 -5
  71. package/src/e.d.ts +3 -0
  72. package/src/global.ts +11 -6
  73. package/src/graph/stream.ts +20 -10
  74. package/src/index.ts +1 -0
  75. package/src/queue/stream_queue.ts +6 -5
  76. package/src/storage/index.ts +42 -4
  77. package/src/storage/redis/queue.ts +148 -0
  78. package/src/storage/sqlite/DB.ts +16 -0
  79. package/src/storage/sqlite/checkpoint.ts +503 -0
  80. package/src/storage/sqlite/threads.ts +366 -0
  81. package/src/storage/sqlite/type.ts +12 -0
  82. package/src/utils/createEntrypointGraph.ts +20 -0
  83. package/src/utils/getGraph.ts +3 -3
  84. package/test/graph/entrypoint.ts +21 -0
  85. package/test/graph/index.ts +45 -6
  86. package/test/hono.ts +5 -0
  87. package/test/test.ts +0 -10
package/src/adapter/zod.ts ADDED
@@ -0,0 +1,135 @@
1
+ import z from 'zod';
2
+
3
+ export const AssistantConfigurable = z
4
+ .object({
5
+ thread_id: z.string().optional(),
6
+ thread_ts: z.string().optional(),
7
+ })
8
+ .catchall(z.unknown());
9
+
10
+ export const AssistantConfig = z
11
+ .object({
12
+ tags: z.array(z.string()).optional(),
13
+ recursion_limit: z.number().int().optional(),
14
+ configurable: AssistantConfigurable.optional(),
15
+ })
16
+ .catchall(z.unknown())
17
+ .describe('The configuration of an assistant.');
18
+
19
+ export const Assistant = z.object({
20
+ assistant_id: z.string().uuid(),
21
+ graph_id: z.string(),
22
+ config: AssistantConfig,
23
+ created_at: z.string(),
24
+ updated_at: z.string(),
25
+ metadata: z.object({}).catchall(z.any()),
26
+ });
27
+
28
+ export const MetadataSchema = z
29
+ .object({
30
+ source: z.union([z.literal('input'), z.literal('loop'), z.literal('update'), z.string()]).optional(),
31
+ step: z.number().optional(),
32
+ writes: z.record(z.unknown()).nullable().optional(),
33
+ parents: z.record(z.string()).optional(),
34
+ })
35
+ .catchall(z.unknown());
36
+
37
+ export const SendSchema = z.object({
38
+ node: z.string(),
39
+ input: z.unknown().nullable(),
40
+ });
41
+
42
+ export const CommandSchema = z.object({
43
+ update: z
44
+ .union([z.record(z.unknown()), z.array(z.tuple([z.string(), z.unknown()]))])
45
+ .nullable()
46
+ .optional(),
47
+ resume: z.unknown().optional(),
48
+ goto: z.union([SendSchema, z.array(SendSchema), z.string(), z.array(z.string())]).optional(),
49
+ });
50
+
51
+ // 公共的查询参数验证 schema
52
+ export const PaginationQuerySchema = z.object({
53
+ limit: z.number().int().optional(),
54
+ offset: z.number().int().optional(),
55
+ });
56
+
57
+ export const ThreadIdParamSchema = z.object({
58
+ thread_id: z.string().uuid(),
59
+ });
60
+
61
+ export const RunIdParamSchema = z.object({
62
+ thread_id: z.string().uuid(),
63
+ run_id: z.string().uuid(),
64
+ });
65
+
66
+ // Assistants 相关的 schema
67
+ export const AssistantsSearchSchema = z.object({
68
+ graph_id: z.string().optional(),
69
+ metadata: MetadataSchema.optional(),
70
+ limit: z.number().int().optional(),
71
+ offset: z.number().int().optional(),
72
+ });
73
+
74
+ export const AssistantGraphQuerySchema = z.object({
75
+ xray: z.string().optional(),
76
+ });
77
+
78
+ // Runs 相关的 schema
79
+ export const RunStreamPayloadSchema = z
80
+ .object({
81
+ assistant_id: z.union([z.string().uuid(), z.string()]),
82
+ checkpoint_id: z.string().optional(),
83
+ input: z.any().optional(),
84
+ command: CommandSchema.optional(),
85
+ metadata: MetadataSchema.optional(),
86
+ config: AssistantConfig.optional(),
87
+ webhook: z.string().optional(),
88
+ interrupt_before: z.union([z.literal('*'), z.array(z.string())]).optional(),
89
+ interrupt_after: z.union([z.literal('*'), z.array(z.string())]).optional(),
90
+ on_disconnect: z.enum(['cancel', 'continue']).optional().default('continue'),
91
+ multitask_strategy: z.enum(['reject', 'rollback', 'interrupt', 'enqueue']).optional(),
92
+ stream_mode: z
93
+ .array(z.enum(['values', 'messages', 'messages-tuple', 'updates', 'events', 'debug', 'custom']))
94
+ .optional(),
95
+ stream_subgraphs: z.boolean().optional(),
96
+ stream_resumable: z.boolean().optional(),
97
+ after_seconds: z.number().optional(),
98
+ if_not_exists: z.enum(['create', 'reject']).optional(),
99
+ on_completion: z.enum(['complete', 'continue']).optional(),
100
+ feedback_keys: z.array(z.string()).optional(),
101
+ langsmith_tracer: z.unknown().optional(),
102
+ })
103
+ .describe('Payload for creating a stateful run.');
104
+
105
+ export const RunListQuerySchema = z.object({
106
+ limit: z.coerce.number().int().optional(),
107
+ offset: z.coerce.number().int().optional(),
108
+ status: z.enum(['pending', 'running', 'error', 'success', 'timeout', 'interrupted']).optional(),
109
+ });
110
+
111
+ export const RunCancelQuerySchema = z.object({
112
+ wait: z.coerce.boolean().optional().default(false),
113
+ action: z.enum(['interrupt', 'rollback']).optional().default('interrupt'),
114
+ });
115
+
116
+ // Threads 相关的 schema
117
+ export const ThreadCreatePayloadSchema = z
118
+ .object({
119
+ thread_id: z.string().uuid().describe('The ID of the thread. If not provided, an ID is generated.').optional(),
120
+ metadata: MetadataSchema.optional(),
121
+ if_exists: z.union([z.literal('raise'), z.literal('do_nothing')]).optional(),
122
+ })
123
+ .describe('Payload for creating a thread.');
124
+
125
+ export const ThreadSearchPayloadSchema = z
126
+ .object({
127
+ metadata: z.record(z.unknown()).describe('Metadata to search for.').optional(),
128
+ status: z.enum(['idle', 'busy', 'interrupted', 'error']).describe('Filter by thread status.').optional(),
129
+ values: z.record(z.unknown()).describe('Filter by thread values.').optional(),
130
+ limit: z.number().int().gte(1).lte(1000).describe('Maximum number to return.').optional(),
131
+ offset: z.number().int().gte(0).describe('Offset to start from.').optional(),
132
+ sort_by: z.enum(['thread_id', 'status', 'created_at', 'updated_at']).describe('Sort by field.').optional(),
133
+ sort_order: z.enum(['asc', 'desc']).describe('Sort order.').optional(),
134
+ })
135
+ .describe('Payload for listing threads.');
package/src/createEndpoint.ts CHANGED
@@ -2,8 +2,7 @@ import { StreamEvent } from '@langchain/core/tracers/log_stream';
2
2
  import { streamState } from './graph/stream.js';
3
3
  import { Assistant, Run, StreamMode, Metadata, AssistantGraph } from '@langchain/langgraph-sdk';
4
4
  import { getGraph, GRAPHS } from './utils/getGraph.js';
5
- import { BaseThreadsManager } from './threads/index.js';
6
- import { globalMessageQueue } from './global.js';
5
+ import { LangGraphGlobal } from './global.js';
7
6
  import { AssistantSortBy, CancelAction, ILangGraphClient, RunStatus, SortOrder, StreamInputData } from './types.js';
8
7
  export { registerGraph } from './utils/getGraph.js';
9
8
 
@@ -57,16 +56,24 @@ export const AssistantEndpoint: ILangGraphClient['assistants'] = {
57
56
  },
58
57
  };
59
58
 
60
- export const createEndpoint = (threads: BaseThreadsManager): ILangGraphClient => {
59
+ export const createEndpoint = (): ILangGraphClient => {
60
+ const threads = LangGraphGlobal.globalThreadsManager;
61
61
  return {
62
62
  assistants: AssistantEndpoint,
63
63
  threads,
64
64
  runs: {
65
- list(threadId: string, options?: { limit?: number; offset?: number; status?: RunStatus }): Promise<Run[]> {
65
+ list(
66
+ threadId: string,
67
+ options?: {
68
+ limit?: number;
69
+ offset?: number;
70
+ status?: RunStatus;
71
+ },
72
+ ): Promise<Run[]> {
66
73
  return threads.listRuns(threadId, options);
67
74
  },
68
75
  async cancel(threadId: string, runId: string, wait?: boolean, action?: CancelAction): Promise<void> {
69
- return globalMessageQueue.cancelQueue(runId);
76
+ return LangGraphGlobal.globalMessageQueue.cancelQueue(runId);
70
77
  },
71
78
  async *stream(threadId: string, assistantId: string, payload: StreamInputData) {
72
79
  if (!payload.config) {
package/src/e.d.ts ADDED
@@ -0,0 +1,3 @@
1
+ declare module 'bun:sqlite' {
2
+ export * from 'better-sqlite3';
3
+ }
package/src/global.ts CHANGED
@@ -1,6 +1,11 @@
1
- import { createCheckPointer, createMessageQueue } from './storage/index.js';
2
-
3
- /** 全局队列管理器 */
4
- export const globalMessageQueue = createMessageQueue();
5
- /** 全局 Checkpointer */
6
- export const globalCheckPointer = createCheckPointer();
1
+ import { createCheckPointer, createMessageQueue, createThreadManager } from './storage/index.js';
2
+ import type { SqliteSaver } from './storage/sqlite/checkpoint.js';
3
+ const [globalMessageQueue, globalCheckPointer] = await Promise.all([createMessageQueue(), createCheckPointer()]);
4
+ const globalThreadsManager = await createThreadManager({
5
+ checkpointer: globalCheckPointer as SqliteSaver,
6
+ });
7
+ export class LangGraphGlobal {
8
+ static globalMessageQueue = globalMessageQueue;
9
+ static globalCheckPointer = globalCheckPointer;
10
+ static globalThreadsManager = globalThreadsManager;
11
+ }
package/src/graph/stream.ts CHANGED
@@ -4,7 +4,7 @@ import type { Pregel } from '@langchain/langgraph/pregel';
4
4
  import { getLangGraphCommand } from '../utils/getLangGraphCommand.js';
5
5
  import type { BaseStreamQueueInterface } from '../queue/stream_queue.js';
6
6
 
7
- import { globalMessageQueue } from '../global.js';
7
+ import { LangGraphGlobal } from '../global.js';
8
8
  import { Run } from '@langgraph-js/sdk';
9
9
  import { EventMessage, StreamErrorEventMessage, StreamEndEventMessage } from '../queue/event_message.js';
10
10
 
@@ -54,7 +54,13 @@ export async function streamStateWithQueue(
54
54
  libStreamMode.add('values');
55
55
  }
56
56
 
57
- await queue.push(new EventMessage('metadata', { run_id: run.run_id, attempt: options.attempt, graph_id: graphId }));
57
+ await queue.push(
58
+ new EventMessage('metadata', {
59
+ run_id: run.run_id,
60
+ attempt: options.attempt,
61
+ graph_id: graphId,
62
+ }),
63
+ );
58
64
 
59
65
  const metadata = {
60
66
  ...payload.config?.metadata,
@@ -107,7 +113,9 @@ export async function streamStateWithQueue(
107
113
  }
108
114
  }
109
115
  if (mode === 'values') {
110
- await threads.set(run.thread_id, { values: JSON.parse(serialiseAsDict(data)) });
116
+ await threads.set(run.thread_id, {
117
+ values: data ? JSON.parse(serialiseAsDict(data)) : '',
118
+ });
111
119
  }
112
120
  } else if (userStreamMode.includes('events')) {
113
121
  await queue.push(new EventMessage('events', event));
@@ -151,7 +159,9 @@ export async function streamStateWithQueue(
151
159
  if (messages[message.id] == null) {
152
160
  messages[message.id] = message;
153
161
  await queue.push(
154
- new EventMessage('messages/metadata', { [message.id]: { metadata: event.metadata } }),
162
+ new EventMessage('messages/metadata', {
163
+ [message.id]: { metadata: event.metadata },
164
+ }),
155
165
  );
156
166
  } else {
157
167
  messages[message.id] = messages[message.id].concat(message);
@@ -174,11 +184,11 @@ export async function streamStateWithQueue(
174
184
  * @returns 数据流生成器
175
185
  */
176
186
  export async function* createStreamFromQueue(queueId: string): AsyncGenerator<{ event: string; data: unknown }> {
177
- const queue = globalMessageQueue.getQueue(queueId);
187
+ const queue = LangGraphGlobal.globalMessageQueue.getQueue(queueId);
178
188
  return queue.onDataReceive();
179
189
  }
180
190
 
181
- export const serialiseAsDict = (obj: unknown) => {
191
+ export const serialiseAsDict = (obj: unknown, indent = 2) => {
182
192
  return JSON.stringify(
183
193
  obj,
184
194
  function (key: string | number, value: unknown) {
@@ -196,7 +206,7 @@ export const serialiseAsDict = (obj: unknown) => {
196
206
 
197
207
  return value;
198
208
  },
199
- 2,
209
+ indent,
200
210
  );
201
211
  };
202
212
  /**
@@ -227,12 +237,12 @@ export async function* streamState(
227
237
  // 启动队列推送任务(在后台异步执行)
228
238
  await threads.set(threadId, { status: 'busy' });
229
239
  await threads.updateRun(run.run_id, { status: 'running' });
230
- const queue = globalMessageQueue.createQueue(queueId);
240
+ const queue = LangGraphGlobal.globalMessageQueue.createQueue(queueId);
231
241
  const state = queue.onDataReceive();
232
242
  streamStateWithQueue(threads, run, queue, payload, options).catch((error) => {
233
243
  console.error('Queue task error:', error);
234
244
  // 如果生产者出错,向队列推送错误信号
235
- globalMessageQueue.pushToQueue(queueId, new StreamErrorEventMessage(error));
245
+ LangGraphGlobal.globalMessageQueue.pushToQueue(queueId, new StreamErrorEventMessage(error));
236
246
  // TODO 不知道这里需不需要错误处理
237
247
  });
238
248
  for await (const data of state) {
@@ -248,6 +258,6 @@ export async function* streamState(
248
258
  } finally {
249
259
  // 在完成后清理队列
250
260
  await threads.set(threadId, { status: 'idle' });
251
- globalMessageQueue.removeQueue(queueId);
261
+ LangGraphGlobal.globalMessageQueue.removeQueue(queueId);
252
262
  }
253
263
  }
package/src/index.ts CHANGED
@@ -3,3 +3,4 @@ export * from './types';
3
3
  export * from './global';
4
4
 
5
5
  export * from './threads/index';
6
+ export * from './utils/createEntrypointGraph';
package/src/queue/stream_queue.ts CHANGED
@@ -26,7 +26,7 @@ export class BaseStreamQueue extends EventEmitter<StreamQueueEvents<EventMessage
26
26
  * Constructor
27
27
  * @param compressMessages 是否压缩消息 / Whether to compress messages
28
28
  */
29
- constructor(readonly compressMessages: boolean = true) {
29
+ constructor(readonly id: string, readonly compressMessages: boolean = true) {
30
30
  super();
31
31
  }
32
32
 
@@ -58,6 +58,7 @@ export class BaseStreamQueue extends EventEmitter<StreamQueueEvents<EventMessage
58
58
  * Base stream queue interface
59
59
  */
60
60
  export interface BaseStreamQueueInterface {
61
+ id: string;
61
62
  /** 是否压缩消息 / Whether to compress messages */
62
63
  compressMessages: boolean;
63
64
  /**
@@ -76,7 +77,7 @@ export interface BaseStreamQueueInterface {
76
77
  * @param listener 数据变化监听器 / Data change listener
77
78
  * @returns 取消监听函数 / Unsubscribe function
78
79
  */
79
- onDataChange(listener: (data: EventMessage) => void): () => void;
80
+ onDataReceive(): AsyncGenerator<EventMessage, void, unknown>;
80
81
  /** 取消信号控制器 / Cancel signal controller */
81
82
  cancelSignal: AbortController;
82
83
  /** 取消操作 / Cancel operation */
@@ -93,7 +94,7 @@ export class StreamQueueManager<Q extends BaseStreamQueueInterface> {
93
94
  /** 默认是否压缩消息 / Default compress messages setting */
94
95
  private defaultCompressMessages: boolean;
95
96
  /** 队列构造函数 / Queue constructor */
96
- private queueConstructor: new (compressMessages: boolean) => Q;
97
+ private queueConstructor: new (queueId: string) => Q;
97
98
 
98
99
  /**
99
100
  * 构造函数
@@ -102,7 +103,7 @@ export class StreamQueueManager<Q extends BaseStreamQueueInterface> {
102
103
  * @param options 配置选项 / Configuration options
103
104
  */
104
105
  constructor(
105
- queueConstructor: new (compressMessages: boolean) => Q,
106
+ queueConstructor: new (id: string) => Q,
106
107
  options: {
107
108
  /** 默认是否压缩消息 / Default compress messages setting */
108
109
  defaultCompressMessages?: boolean;
@@ -121,7 +122,7 @@ export class StreamQueueManager<Q extends BaseStreamQueueInterface> {
121
122
  */
122
123
  createQueue(id: string, compressMessages?: boolean): Q {
123
124
  const compress = compressMessages ?? this.defaultCompressMessages;
124
- this.queues.set(id, new this.queueConstructor(compress));
125
+ this.queues.set(id, new this.queueConstructor(id));
125
126
  return this.queues.get(id)!;
126
127
  }
127
128
 
package/src/storage/index.ts CHANGED
@@ -1,14 +1,52 @@
1
- import { StreamQueueManager } from '../queue/stream_queue';
1
+ import { BaseStreamQueueInterface, StreamQueueManager } from '../queue/stream_queue';
2
2
  import { MemorySaver } from './memory/checkpoint';
3
3
  import { MemoryStreamQueue } from './memory/queue';
4
+ import { MemoryThreadsManager } from './memory/threads';
5
+ import type { SqliteSaver as SqliteSaverType } from './sqlite/checkpoint';
6
+ import { SQLiteThreadsManager } from './sqlite/threads';
4
7
 
5
8
  // 所有的适配实现,都请写到这里,通过环境变量进行判断使用哪种方式进行适配
6
9
 
7
- export const createCheckPointer = () => {
10
+ export const createCheckPointer = async () => {
11
+ if (
12
+ (process.env.REDIS_URL && process.env.CHECKPOINT_TYPE === 'redis') ||
13
+ process.env.CHECKPOINT_TYPE === 'shallow/redis'
14
+ ) {
15
+ if (process.env.CHECKPOINT_TYPE === 'redis') {
16
+ const { RedisSaver } = await import('@langchain/langgraph-checkpoint-redis');
17
+ return await RedisSaver.fromUrl(process.env.REDIS_URL!, {
18
+ defaultTTL: 60, // TTL in minutes
19
+ refreshOnRead: true,
20
+ });
21
+ }
22
+ if (process.env.CHECKPOINT_TYPE === 'shallow/redis') {
23
+ const { ShallowRedisSaver } = await import('@langchain/langgraph-checkpoint-redis/shallow');
24
+ return await ShallowRedisSaver.fromUrl(process.env.REDIS_URL!);
25
+ }
26
+ }
27
+ if (process.env.SQLITE_DATABASE_URI) {
28
+ const { SqliteSaver } = await import('./sqlite/checkpoint');
29
+ const db = SqliteSaver.fromConnString(process.env.SQLITE_DATABASE_URI);
30
+ return db;
31
+ }
8
32
  return new MemorySaver();
9
33
  };
10
34
 
11
- export const createMessageQueue = () => {
12
- const q: new (compressMessages: boolean) => MemoryStreamQueue = MemoryStreamQueue;
35
+ export const createMessageQueue = async () => {
36
+ let q: new (id: string) => BaseStreamQueueInterface;
37
+ if (process.env.REDIS_URL) {
38
+ console.log('Using redis as stream queue');
39
+ const { RedisStreamQueue } = await import('./redis/queue');
40
+ q = RedisStreamQueue;
41
+ } else {
42
+ q = MemoryStreamQueue;
43
+ }
13
44
  return new StreamQueueManager(q);
14
45
  };
46
+
47
+ export const createThreadManager = (config: { checkpointer?: SqliteSaverType }) => {
48
+ if (process.env.SQLITE_DATABASE_URI && config.checkpointer) {
49
+ return new SQLiteThreadsManager(config.checkpointer);
50
+ }
51
+ return new MemoryThreadsManager();
52
+ };
package/src/storage/redis/queue.ts ADDED
@@ -0,0 +1,148 @@
1
+ import { CancelEventMessage, EventMessage } from '../../queue/event_message.js';
2
+ import { BaseStreamQueue } from '../../queue/stream_queue.js';
3
+ import { BaseStreamQueueInterface } from '../../queue/stream_queue.js';
4
+ import { createClient, RedisClientType } from 'redis';
5
+
6
+ /**
7
+ * Redis 实现的消息队列,用于存储消息
8
+ */
9
+ export class RedisStreamQueue extends BaseStreamQueue implements BaseStreamQueueInterface {
10
+ static redis: RedisClientType = createClient({ url: process.env.REDIS_URL! });
11
+ static subscriberRedis: RedisClientType = createClient({ url: process.env.REDIS_URL! });
12
+ private redis: RedisClientType;
13
+ private subscriberRedis: RedisClientType;
14
+ private queueKey: string;
15
+ private channelKey: string;
16
+ private isConnected = false;
17
+ public cancelSignal: AbortController;
18
+
19
+ constructor(readonly id: string = 'default') {
20
+ super(id, true);
21
+ this.queueKey = `queue:${this.id}`;
22
+ this.channelKey = `channel:${this.id}`;
23
+ this.redis = RedisStreamQueue.redis;
24
+ this.subscriberRedis = RedisStreamQueue.subscriberRedis;
25
+ this.cancelSignal = new AbortController();
26
+
27
+ // 连接 Redis 客户端
28
+ this.redis.connect();
29
+ this.subscriberRedis.connect();
30
+ this.isConnected = true;
31
+ }
32
+
33
+ /**
34
+ * 推送消息到 Redis 队列
35
+ */
36
+ async push(item: EventMessage): Promise<void> {
37
+ const data = await this.encodeData(item);
38
+ const serializedData = Buffer.from(data);
39
+
40
+ // 推送到队列
41
+ await this.redis.lPush(this.queueKey, serializedData);
42
+
43
+ // 设置队列 TTL 为 300 秒
44
+ await this.redis.expire(this.queueKey, 300);
45
+
46
+ // 发布到频道通知有新数据
47
+ await this.redis.publish(this.channelKey, serializedData);
48
+
49
+ this.emit('dataChange', data);
50
+ }
51
+
52
+ /**
53
+ * 异步生成器:支持 for await...of 方式消费队列数据
54
+ */
55
+ async *onDataReceive(): AsyncGenerator<EventMessage, void, unknown> {
56
+ let queue: EventMessage[] = [];
57
+ let pendingResolve: (() => void) | null = null;
58
+ let isStreamEnded = false;
59
+ const handleMessage = async (message: string) => {
60
+ const data = (await this.decodeData(message)) as EventMessage;
61
+ queue.push(data);
62
+ // 检查是否为流结束或错误信号
63
+ if (
64
+ data.event === '__stream_end__' ||
65
+ data.event === '__stream_error__' ||
66
+ data.event === '__stream_cancel__'
67
+ ) {
68
+ setTimeout(() => {
69
+ isStreamEnded = true;
70
+ if (pendingResolve) {
71
+ pendingResolve();
72
+ pendingResolve = null;
73
+ }
74
+ }, 300);
75
+
76
+ if (data.event === '__stream_cancel__') {
77
+ this.cancel();
78
+ }
79
+ }
80
+
81
+ if (pendingResolve) {
82
+ pendingResolve();
83
+ pendingResolve = null;
84
+ }
85
+ };
86
+
87
+ // 订阅 Redis 频道
88
+ await this.subscriberRedis.subscribe(this.channelKey, (message) => {
89
+ handleMessage(message);
90
+ });
91
+
92
+ try {
93
+ while (!isStreamEnded) {
94
+ if (queue.length > 0) {
95
+ for (const item of queue) {
96
+ yield item;
97
+ }
98
+ queue = [];
99
+ } else {
100
+ await new Promise((resolve) => {
101
+ pendingResolve = resolve as () => void;
102
+ });
103
+ }
104
+ }
105
+ } finally {
106
+ await this.subscriberRedis.unsubscribe(this.channelKey);
107
+ }
108
+ }
109
+
110
+ /**
111
+ * 获取队列中的所有数据
112
+ */
113
+ async getAll(): Promise<EventMessage[]> {
114
+ const data = await this.redis.lRange(this.queueKey, 0, -1);
115
+
116
+ if (!data || data.length === 0) {
117
+ return [];
118
+ }
119
+
120
+ if (this.compressMessages) {
121
+ return (await Promise.all(
122
+ data.map(async (item: string) => {
123
+ const parsed = JSON.parse(item) as EventMessage;
124
+ return (await this.decodeData(parsed as any)) as EventMessage;
125
+ }),
126
+ )) as EventMessage[];
127
+ } else {
128
+ return data.map((item: string) => JSON.parse(item) as EventMessage);
129
+ }
130
+ }
131
+
132
+ /**
133
+ * 清空队列
134
+ */
135
+ clear(): void {
136
+ if (this.isConnected) {
137
+ this.redis.del(this.queueKey);
138
+ }
139
+ }
140
+
141
+ /**
142
+ * 取消操作
143
+ */
144
+ cancel(): void {
145
+ this.push(new CancelEventMessage());
146
+ this.cancelSignal.abort('user cancel this run');
147
+ }
148
+ }
package/src/storage/sqlite/DB.ts ADDED
@@ -0,0 +1,16 @@
1
+ import { DatabaseType } from './type';
2
+
3
+ let Database: new (uri: string) => DatabaseType;
4
+ /** @ts-ignore */
5
+ if (globalThis.Bun) {
6
+ console.log('Using Bun Sqlite, pid:', process.pid);
7
+ const BunSqlite = await import('bun:sqlite');
8
+ /** @ts-ignore */
9
+ Database = BunSqlite.default;
10
+ } else {
11
+ /** @ts-ignore */
12
+ const CommonSqlite = await import('better-sqlite3');
13
+ Database = CommonSqlite.default;
14
+ }
15
+
16
+ export { Database };