@zhin.js/ai 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/session.ts ADDED
@@ -0,0 +1,537 @@
1
+ /**
2
+ * @zhin.js/ai - Session Manager
3
+ * 会话管理器,支持上下文记忆和数据库持久化
4
+ *
5
+ * 特性:
6
+ * - 数据库持久化存储(使用 Zhin 的数据库服务)
7
+ * - 内存缓存加速读取
8
+ * - 自动过期清理
9
+ * - 更长的上下文记忆能力
10
+ */
11
+
12
+ import { Logger } from '@zhin.js/core';
13
+ import type { ChatMessage, SessionConfig, Session } from './types.js';
14
+
15
// Module-scoped logger shared by the session managers in this file.
// NOTE(review): the second argument is presumably the logger's name/namespace — confirm against @zhin.js/core's Logger.
const logger = new Logger(null, 'AI-Session');
16
+
17
/**
 * Column definitions for the persisted AI session table.
 *
 * Each key is a column name mapped to a descriptor holding the column `type`
 * plus either a `nullable` flag or a `default` value.
 * NOTE(review): the descriptor shape is presumably what Zhin's database service
 * expects for model definitions — confirm against that API.
 */
export const AI_SESSION_MODEL = {
  // Session identifier; the only column declared non-nullable.
  session_id: { type: 'text' as const, nullable: false },

  // Conversation history, defaulting to an empty list.
  messages: { type: 'json' as const, default: [] },

  // Per-session configuration, defaulting to an empty object.
  config: { type: 'json' as const, default: {} },

  // Creation / last-update times, defaulting to 0 (managers in this file store Date.now() values).
  created_at: { type: 'integer' as const, default: 0 },
  updated_at: { type: 'integer' as const, default: 0 },
};
27
+
28
/**
 * Shape of one session row as read from / written to the database model.
 * Field names mirror the columns declared in AI_SESSION_MODEL.
 */
interface SessionRecord {
  /** Row id; optional because it is never set by the code in this file. */
  id?: number;
  /** Session identifier used as the lookup key. */
  session_id: string;
  /** Stored conversation history. */
  messages: ChatMessage[];
  /** Stored per-session configuration. */
  config: SessionConfig;
  /** Creation time (a Date.now() value when written by this file). */
  created_at: number;
  /** Last-update time (a Date.now() value when written by this file). */
  updated_at: number;
}
39
+
40
/** A value that is available either immediately or through a promise. */
type MaybePromise<T> = T | Promise<T>;

/**
 * Contract shared by every session manager in this file.
 *
 * Each method may answer synchronously or asynchronously, so callers that do
 * not know the concrete implementation should `await` the result.
 */
export interface ISessionManager {
  /** Fetch the session for `sessionId`; `config` is an optional configuration for it. */
  get(sessionId: string, config?: SessionConfig): MaybePromise<Session>;
  /** Report whether a session with this id exists. */
  has(sessionId: string): MaybePromise<boolean>;
  /** Append one message to the session's history. */
  addMessage(sessionId: string, message: ChatMessage): MaybePromise<void>;
  /** Return the session's message history. */
  getMessages(sessionId: string): MaybePromise<ChatMessage[]>;
  /** Set the session's system prompt. */
  setSystemPrompt(sessionId: string, prompt: string): MaybePromise<void>;
  /** Remove the session; the boolean reports the outcome. */
  clear(sessionId: string): MaybePromise<boolean>;
  /** Reset the session's history. */
  reset(sessionId: string): MaybePromise<void>;
  /** List known session ids. */
  listSessions(): MaybePromise<string[]>;
  /** Count sessions in total and split into active / expired. */
  getStats(): MaybePromise<{ total: number; active: number; expired: number }>;
  /** Remove expired sessions and return how many were removed. */
  cleanup(): MaybePromise<number>;
  /** Release resources held by the manager. */
  dispose(): MaybePromise<void>;
}
56
+
57
/**
 * In-memory session manager (fallback when no database is available).
 *
 * Sessions live in a Map keyed by session id and are lost when the process
 * exits. A background interval removes expired sessions every 5 minutes; call
 * `dispose()` to stop it.
 */
export class MemorySessionManager implements ISessionManager {
  private sessions: Map<string, Session> = new Map();
  private config: Required<Pick<SessionConfig, 'maxHistory' | 'expireMs'>>;
  private cleanupTimer?: ReturnType<typeof setInterval>;

  /**
   * @param config Manager-wide defaults: `maxHistory` (default 100 messages)
   *   and `expireMs` (default 24 hours). A session's own config overrides them.
   */
  constructor(config: { maxHistory?: number; expireMs?: number } = {}) {
    this.config = {
      maxHistory: config.maxHistory ?? 100,
      expireMs: config.expireMs ?? 24 * 60 * 60 * 1000, // 24 hours
    };

    // Periodically sweep expired sessions (every 5 minutes).
    this.cleanupTimer = setInterval(() => this.cleanup(), 5 * 60 * 1000);
  }

  /**
   * Return the session for `sessionId`, creating it when absent.
   *
   * `config` is only used when the session is created (falling back to
   * `{ provider: 'openai' }`); for an existing session it is ignored and the
   * session's `updatedAt` is refreshed instead.
   */
  get(sessionId: string, config?: SessionConfig): Session {
    let session = this.sessions.get(sessionId);

    if (!session) {
      session = {
        id: sessionId,
        config: config ?? { provider: 'openai' },
        messages: [],
        createdAt: Date.now(),
        updatedAt: Date.now(),
      };
      this.sessions.set(sessionId, session);
    } else {
      session.updatedAt = Date.now();
    }

    return session;
  }

  /** Whether a session with this id is currently stored. */
  has(sessionId: string): boolean {
    return this.sessions.has(sessionId);
  }

  /** Append a message (creating the session if needed) and trim the history. */
  addMessage(sessionId: string, message: ChatMessage): void {
    const session = this.get(sessionId);
    session.messages.push(message);
    session.updatedAt = Date.now();
    this.trimMessages(session);
  }

  /**
   * Enforce the history limit (session `maxHistory`, else the manager default).
   *
   * All system messages are kept and moved to the front; the remaining budget
   * is filled with the most recent non-system messages. When system messages
   * alone meet or exceed the limit, no non-system message is kept.
   */
  private trimMessages(session: Session): void {
    const maxHistory = session.config.maxHistory ?? this.config.maxHistory;
    if (session.messages.length <= maxHistory) {
      return;
    }

    const systemMessages = session.messages.filter(m => m.role === 'system');
    const otherMessages = session.messages.filter(m => m.role !== 'system');
    const keepCount = maxHistory - systemMessages.length;
    // Guard keepCount <= 0 explicitly: slice(-0) is slice(0) and would keep
    // everything, and a negative count would drop from the wrong end.
    const recent = keepCount > 0 ? otherMessages.slice(-keepCount) : [];
    session.messages = [...systemMessages, ...recent];
  }

  /** Messages of an existing session, or an empty array; never creates a session. */
  getMessages(sessionId: string): ChatMessage[] {
    return this.sessions.get(sessionId)?.messages ?? [];
  }

  /** Replace every existing system message with a single one at the front. */
  setSystemPrompt(sessionId: string, prompt: string): void {
    const session = this.get(sessionId);
    session.messages = session.messages.filter(m => m.role !== 'system');
    session.messages.unshift({ role: 'system', content: prompt });
    session.updatedAt = Date.now();
  }

  /** Delete the session; returns whether it existed. */
  clear(sessionId: string): boolean {
    return this.sessions.delete(sessionId);
  }

  /** Drop all non-system messages of an existing session; no-op when it does not exist. */
  reset(sessionId: string): void {
    const session = this.sessions.get(sessionId);
    if (session) {
      session.messages = session.messages.filter(m => m.role === 'system');
      session.updatedAt = Date.now();
    }
  }

  /** Ids of all stored sessions. */
  listSessions(): string[] {
    return Array.from(this.sessions.keys());
  }

  /**
   * Count stored sessions, split by the same expiry rule `cleanup()` applies
   * (per-session `expireMs` first, manager default otherwise).
   */
  getStats(): { total: number; active: number; expired: number } {
    const now = Date.now();
    let active = 0;
    let expired = 0;

    for (const session of this.sessions.values()) {
      if (this.isExpired(session, now)) {
        expired++;
      } else {
        active++;
      }
    }

    return { total: this.sessions.size, active, expired };
  }

  /** Delete every expired session and return how many were removed. */
  cleanup(): number {
    const now = Date.now();
    let cleaned = 0;

    for (const [id, session] of this.sessions) {
      if (this.isExpired(session, now)) {
        this.sessions.delete(id);
        cleaned++;
      }
    }

    return cleaned;
  }

  /** Stop the cleanup interval and drop all sessions. */
  dispose(): void {
    if (this.cleanupTimer) {
      clearInterval(this.cleanupTimer);
      this.cleanupTimer = undefined;
    }
    this.sessions.clear();
  }

  /** True when the session was last updated more than its expiry window before `now`. */
  private isExpired(session: Session, now: number): boolean {
    const expireMs = session.config.expireMs ?? this.config.expireMs;
    return now - session.updatedAt > expireMs;
  }
}
182
+
183
/**
 * Database-backed session manager.
 *
 * Sessions are persisted through the supplied database model and mirrored in
 * an in-memory cache. Writes are debounced: changed sessions are queued and
 * flushed to the database about one second later (or on `dispose()`). An
 * hourly interval removes expired sessions.
 *
 * The `model` is used through `select(where?)`, `create(record)`,
 * `update(record, where)` and `delete(where)`.
 * NOTE(review): these are presumably the methods of Zhin's database model —
 * confirm against that API.
 */
export class DatabaseSessionManager implements ISessionManager {
  private cache: Map<string, Session> = new Map();
  private config: Required<Pick<SessionConfig, 'maxHistory' | 'expireMs'>>;
  private cleanupTimer?: ReturnType<typeof setInterval>;
  private saveQueue: Map<string, Session> = new Map();
  private saveTimer?: ReturnType<typeof setTimeout>;
  // Database model; left untyped because its interface is not declared in this file.
  private model: any;

  /**
   * @param model Database model used for persistence.
   * @param config Manager-wide defaults: `maxHistory` (default 200 messages)
   *   and `expireMs` (default 7 days). A session's own config overrides them.
   */
  constructor(
    model: any,
    config: { maxHistory?: number; expireMs?: number } = {}
  ) {
    this.model = model;
    this.config = {
      maxHistory: config.maxHistory ?? 200,
      expireMs: config.expireMs ?? 7 * 24 * 60 * 60 * 1000, // 7 days
    };

    // Periodically sweep expired sessions (every hour). cleanup() handles its
    // own errors, so the returned promise is intentionally not awaited.
    this.cleanupTimer = setInterval(() => void this.cleanup(), 60 * 60 * 1000);
  }

  /**
   * Load one session from the database.
   * Returns null when no record exists or when the query fails (the failure is
   * logged at debug level).
   */
  private async loadSession(sessionId: string): Promise<Session | null> {
    try {
      const records = await this.model.select({ session_id: sessionId });
      if (records && records.length > 0) {
        const record = records[0] as SessionRecord;
        return {
          id: record.session_id,
          config: record.config || { provider: 'openai' },
          messages: record.messages || [],
          createdAt: record.created_at,
          updatedAt: record.updated_at,
        };
      }
    } catch (error) {
      logger.debug('加载会话失败:', error);
    }
    return null;
  }

  /**
   * Queue a session for persistence (debounced).
   * The first queued session arms a 1-second timer; sessions queued before it
   * fires are written in the same batch.
   */
  private scheduleSave(session: Session): void {
    this.saveQueue.set(session.id, session);

    if (!this.saveTimer) {
      this.saveTimer = setTimeout(() => void this.flushSaveQueue(), 1000);
    }
  }

  /**
   * Write every queued session to the database, updating an existing record
   * or creating a new one. Failures are logged per session and do not stop
   * the remaining writes.
   */
  private async flushSaveQueue(): Promise<void> {
    this.saveTimer = undefined;

    // Snapshot and clear first so sessions queued during the awaits below
    // land in the next batch.
    const sessions = Array.from(this.saveQueue.values());
    this.saveQueue.clear();

    for (const session of sessions) {
      try {
        const existing = await this.model.select({ session_id: session.id });
        const record: Partial<SessionRecord> = {
          session_id: session.id,
          messages: session.messages,
          config: session.config,
          updated_at: session.updatedAt,
        };

        if (existing && existing.length > 0) {
          await this.model.update(record, { session_id: session.id });
        } else {
          // created_at is only written when the record is first created.
          record.created_at = session.createdAt;
          await this.model.create(record);
        }
      } catch (error) {
        logger.debug(`保存会话 ${session.id} 失败:`, error);
      }
    }
  }

  /**
   * Return the session for `sessionId`: from the cache, else from the
   * database, else newly created with `config` (default
   * `{ provider: 'openai' }`). Every call refreshes `updatedAt` and queues a
   * save.
   */
  async get(sessionId: string, config?: SessionConfig): Promise<Session> {
    let session = this.cache.get(sessionId);

    if (!session) {
      const loaded = await this.loadSession(sessionId);

      // Re-check the cache: a concurrent get() for the same id may have
      // populated it while the load was awaited. Reusing that object keeps
      // all callers mutating the same Session instead of orphaned copies.
      session = this.cache.get(sessionId) ?? loaded ?? {
        id: sessionId,
        config: config ?? { provider: 'openai' },
        messages: [],
        createdAt: Date.now(),
        updatedAt: Date.now(),
      };

      this.cache.set(sessionId, session);
    }

    session.updatedAt = Date.now();
    this.scheduleSave(session);

    return session;
  }

  /**
   * Whether the session exists in the cache or in the database.
   * Always resolves to a real boolean; a failed query resolves to false.
   */
  async has(sessionId: string): Promise<boolean> {
    if (this.cache.has(sessionId)) {
      return true;
    }

    try {
      const records = await this.model.select({ session_id: sessionId });
      return Boolean(records && records.length > 0);
    } catch {
      return false;
    }
  }

  /** Append a message (creating the session if needed), trim, and queue a save. */
  async addMessage(sessionId: string, message: ChatMessage): Promise<void> {
    const session = await this.get(sessionId);
    session.messages.push(message);
    session.updatedAt = Date.now();
    this.trimMessages(session);
    this.scheduleSave(session);
  }

  /**
   * Enforce the history limit (session `maxHistory`, else the manager default).
   *
   * All system messages are kept and moved to the front; the remaining budget
   * is filled with the most recent non-system messages. When system messages
   * alone meet or exceed the limit, no non-system message is kept.
   */
  private trimMessages(session: Session): void {
    const maxHistory = session.config.maxHistory ?? this.config.maxHistory;
    if (session.messages.length <= maxHistory) {
      return;
    }

    const systemMessages = session.messages.filter(m => m.role === 'system');
    const otherMessages = session.messages.filter(m => m.role !== 'system');
    const keepCount = maxHistory - systemMessages.length;
    // Guard keepCount <= 0 explicitly: slice(-0) is slice(0) and would keep
    // everything, and a negative count would drop from the wrong end.
    const recent = keepCount > 0 ? otherMessages.slice(-keepCount) : [];
    session.messages = [...systemMessages, ...recent];
  }

  /**
   * Messages of the session. Goes through get(), so an unknown id creates
   * (and queues a save for) an empty session.
   */
  async getMessages(sessionId: string): Promise<ChatMessage[]> {
    const session = await this.get(sessionId);
    return session.messages;
  }

  /** Replace every existing system message with a single one at the front. */
  async setSystemPrompt(sessionId: string, prompt: string): Promise<void> {
    const session = await this.get(sessionId);
    session.messages = session.messages.filter(m => m.role !== 'system');
    session.messages.unshift({ role: 'system', content: prompt });
    session.updatedAt = Date.now();
    this.scheduleSave(session);
  }

  /**
   * Remove the session from the cache, the pending-save queue and the
   * database. Resolves to false when the database delete fails.
   */
  async clear(sessionId: string): Promise<boolean> {
    this.cache.delete(sessionId);
    this.saveQueue.delete(sessionId);

    try {
      await this.model.delete({ session_id: sessionId });
      return true;
    } catch (error) {
      logger.debug(`删除会话 ${sessionId} 失败:`, error);
      return false;
    }
  }

  /** Drop all non-system messages; goes through get(), so the session is created if missing. */
  async reset(sessionId: string): Promise<void> {
    const session = await this.get(sessionId);
    session.messages = session.messages.filter(m => m.role === 'system');
    session.updatedAt = Date.now();
    this.scheduleSave(session);
  }

  /** Session ids stored in the database; falls back to cached ids when the query fails. */
  async listSessions(): Promise<string[]> {
    try {
      const records = await this.model.select();
      return (records as SessionRecord[]).map(r => r.session_id);
    } catch (error) {
      logger.debug('列出会话失败:', error);
      return Array.from(this.cache.keys());
    }
  }

  /**
   * Count sessions, split into active / expired using the same rule as
   * cleanup(). Falls back to counting the cache when the query fails.
   */
  async getStats(): Promise<{ total: number; active: number; expired: number }> {
    const now = Date.now();
    let total = 0;
    let active = 0;
    let expired = 0;

    try {
      const records = await this.model.select() as SessionRecord[];
      total = records.length;

      for (const record of records) {
        if (this.isRecordExpired(record, now)) {
          expired++;
        } else {
          active++;
        }
      }
    } catch (error) {
      logger.debug('获取统计失败:', error);
      // Fall back to cache-only statistics; reset the counters in case the
      // failure happened part-way through the loop above.
      total = this.cache.size;
      active = 0;
      expired = 0;
      for (const session of this.cache.values()) {
        const expireMs = session.config?.expireMs ?? this.config.expireMs;
        if (now - session.updatedAt > expireMs) {
          expired++;
        } else {
          active++;
        }
      }
    }

    return { total, active, expired };
  }

  /**
   * Delete every expired session from the database, the cache and the
   * pending-save queue. Returns the number deleted; if a database call fails
   * the sweep stops and the count so far is returned.
   */
  async cleanup(): Promise<number> {
    const now = Date.now();
    let cleaned = 0;

    try {
      const records = await this.model.select() as SessionRecord[];

      for (const record of records) {
        if (this.isRecordExpired(record, now)) {
          await this.model.delete({ session_id: record.session_id });
          this.cache.delete(record.session_id);
          // Drop any pending save too, otherwise the next flush would
          // re-create the record that was just deleted.
          this.saveQueue.delete(record.session_id);
          cleaned++;
        }
      }
    } catch (error) {
      logger.debug('清理会话失败:', error);
    }

    return cleaned;
  }

  /** Flush pending saves, stop both timers and empty the cache. */
  async dispose(): Promise<void> {
    // Persist everything still waiting in the queue before shutting down.
    if (this.saveTimer) {
      clearTimeout(this.saveTimer);
      this.saveTimer = undefined;
    }
    await this.flushSaveQueue();

    if (this.cleanupTimer) {
      clearInterval(this.cleanupTimer);
      this.cleanupTimer = undefined;
    }

    this.cache.clear();
  }

  /**
   * Expiry check for a database record. Saves are debounced, so the stored
   * `updated_at` can lag behind the cached session; the more recent of the
   * two timestamps decides.
   */
  private isRecordExpired(record: SessionRecord, now: number): boolean {
    const cached = this.cache.get(record.session_id);
    const lastTouched = Math.max(record.updated_at, cached?.updatedAt ?? 0);
    const expireMs = record.config?.expireMs ?? this.config.expireMs;
    return now - lastTouched > expireMs;
  }
}
448
+
449
/**
 * Session manager wrapper.
 *
 * Holds one concrete ISessionManager and forwards every call to it unchanged,
 * so results are synchronous or promise-based depending on the wrapped
 * manager.
 */
export class SessionManager implements ISessionManager {
  /** @param manager The concrete manager that every call is forwarded to. */
  constructor(private readonly manager: ISessionManager) {}

  /**
   * Build a session id from its parts.
   * With a (non-empty) channel id: `platform:channelId:userId`;
   * otherwise: `platform:userId`.
   */
  static generateId(platform: string, userId: string, channelId?: string): string {
    if (!channelId) {
      return `${platform}:${userId}`;
    }
    return `${platform}:${channelId}:${userId}`;
  }

  /** Forwards to the wrapped manager's `get`. */
  get(sessionId: string, config?: SessionConfig): Session | Promise<Session> {
    const result = this.manager.get(sessionId, config);
    return result;
  }

  /** Forwards to the wrapped manager's `has`. */
  has(sessionId: string): boolean | Promise<boolean> {
    const result = this.manager.has(sessionId);
    return result;
  }

  /** Forwards to the wrapped manager's `addMessage`. */
  addMessage(sessionId: string, message: ChatMessage): void | Promise<void> {
    const result = this.manager.addMessage(sessionId, message);
    return result;
  }

  /** Forwards to the wrapped manager's `getMessages`. */
  getMessages(sessionId: string): ChatMessage[] | Promise<ChatMessage[]> {
    const result = this.manager.getMessages(sessionId);
    return result;
  }

  /** Forwards to the wrapped manager's `setSystemPrompt`. */
  setSystemPrompt(sessionId: string, prompt: string): void | Promise<void> {
    const result = this.manager.setSystemPrompt(sessionId, prompt);
    return result;
  }

  /** Forwards to the wrapped manager's `clear`. */
  clear(sessionId: string): boolean | Promise<boolean> {
    const result = this.manager.clear(sessionId);
    return result;
  }

  /** Forwards to the wrapped manager's `reset`. */
  reset(sessionId: string): void | Promise<void> {
    const result = this.manager.reset(sessionId);
    return result;
  }

  /** Forwards to the wrapped manager's `listSessions`. */
  listSessions(): string[] | Promise<string[]> {
    const result = this.manager.listSessions();
    return result;
  }

  /** Forwards to the wrapped manager's `getStats`. */
  getStats(): { total: number; active: number; expired: number } | Promise<{ total: number; active: number; expired: number }> {
    const result = this.manager.getStats();
    return result;
  }

  /** Forwards to the wrapped manager's `cleanup`. */
  cleanup(): number | Promise<number> {
    const result = this.manager.cleanup();
    return result;
  }

  /** Forwards to the wrapped manager's `dispose`. */
  dispose(): void | Promise<void> {
    const result = this.manager.dispose();
    return result;
  }
}
513
+
514
/**
 * Create a SessionManager backed by an in-memory store (fallback option).
 *
 * @param config Optional `maxHistory` / `expireMs` passed to MemorySessionManager.
 * @returns A SessionManager wrapping a new MemorySessionManager.
 */
export function createMemorySessionManager(config?: { maxHistory?: number; expireMs?: number }): SessionManager {
  const backend = new MemorySessionManager(config);
  return new SessionManager(backend);
}
520
+
521
/**
 * Create a SessionManager backed by the database.
 *
 * @param model Database model handed to DatabaseSessionManager.
 * @param config Optional `maxHistory` / `expireMs` passed to DatabaseSessionManager.
 * @returns A SessionManager wrapping a new DatabaseSessionManager.
 */
export function createDatabaseSessionManager(
  model: any,
  config?: { maxHistory?: number; expireMs?: number }
): SessionManager {
  const backend = new DatabaseSessionManager(model, config);
  return new SessionManager(backend);
}
530
+
531
/**
 * Create a session manager (kept for backward compatibility).
 * Delegates to createMemorySessionManager, so the result is memory-backed.
 *
 * @deprecated Use createMemorySessionManager or createDatabaseSessionManager.
 * @param config Optional `maxHistory` / `expireMs` forwarded unchanged.
 */
export function createSessionManager(config?: { maxHistory?: number; expireMs?: number }): SessionManager {
  const memoryBacked = createMemorySessionManager(config);
  return memoryBacked;
}