@lobehub/chat 0.162.25 → 0.163.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. package/.github/workflows/release.yml +21 -2
  2. package/.github/workflows/sync.yml +1 -1
  3. package/.github/workflows/test.yml +35 -4
  4. package/CHANGELOG.md +25 -0
  5. package/LICENSE +38 -21
  6. package/codecov.yml +11 -0
  7. package/drizzle.config.ts +29 -0
  8. package/next.config.mjs +3 -0
  9. package/package.json +24 -4
  10. package/scripts/migrateServerDB/index.ts +30 -0
  11. package/src/app/(main)/(mobile)/me/(home)/features/useCategory.tsx +2 -1
  12. package/src/app/(main)/chat/@session/features/SessionListContent/List/Item/Actions.tsx +95 -88
  13. package/src/app/(main)/chat/settings/features/HeaderContent.tsx +37 -31
  14. package/src/app/api/webhooks/clerk/__tests__/fixtures/createUser.json +73 -0
  15. package/src/app/api/webhooks/clerk/route.ts +159 -0
  16. package/src/app/api/webhooks/clerk/validateRequest.ts +22 -0
  17. package/src/app/trpc/edge/[trpc]/route.ts +1 -1
  18. package/src/app/trpc/lambda/[trpc]/route.ts +26 -0
  19. package/src/config/auth.ts +2 -0
  20. package/src/config/db.ts +13 -1
  21. package/src/database/server/core/db.ts +44 -0
  22. package/src/database/server/core/dbForTest.ts +45 -0
  23. package/src/database/server/index.ts +1 -0
  24. package/src/database/server/migrations/0000_init.sql +439 -0
  25. package/src/database/server/migrations/0001_add_client_id.sql +9 -0
  26. package/src/database/server/migrations/0002_amusing_puma.sql +9 -0
  27. package/src/database/server/migrations/meta/0000_snapshot.json +1583 -0
  28. package/src/database/server/migrations/meta/0001_snapshot.json +1636 -0
  29. package/src/database/server/migrations/meta/0002_snapshot.json +1630 -0
  30. package/src/database/server/migrations/meta/_journal.json +27 -0
  31. package/src/database/server/models/__tests__/file.test.ts +140 -0
  32. package/src/database/server/models/__tests__/message.test.ts +847 -0
  33. package/src/database/server/models/__tests__/plugin.test.ts +172 -0
  34. package/src/database/server/models/__tests__/session.test.ts +595 -0
  35. package/src/database/server/models/__tests__/topic.test.ts +623 -0
  36. package/src/database/server/models/__tests__/user.test.ts +173 -0
  37. package/src/database/server/models/_template.ts +44 -0
  38. package/src/database/server/models/file.ts +51 -0
  39. package/src/database/server/models/message.ts +378 -0
  40. package/src/database/server/models/plugin.ts +63 -0
  41. package/src/database/server/models/session.ts +290 -0
  42. package/src/database/server/models/sessionGroup.ts +69 -0
  43. package/src/database/server/models/topic.ts +265 -0
  44. package/src/database/server/models/user.ts +138 -0
  45. package/src/database/server/modules/DataImporter/__tests__/fixtures/messages.json +1101 -0
  46. package/src/database/server/modules/DataImporter/__tests__/index.test.ts +954 -0
  47. package/src/database/server/modules/DataImporter/index.ts +333 -0
  48. package/src/database/server/schemas/_id.ts +15 -0
  49. package/src/database/server/schemas/lobechat.ts +601 -0
  50. package/src/database/server/utils/idGenerator.test.ts +39 -0
  51. package/src/database/server/utils/idGenerator.ts +26 -0
  52. package/src/features/User/UserPanel/useMenu.tsx +43 -37
  53. package/src/libs/trpc/client.ts +52 -3
  54. package/src/server/files/s3.ts +21 -1
  55. package/src/server/keyVaultsEncrypt/index.test.ts +62 -0
  56. package/src/server/keyVaultsEncrypt/index.ts +93 -0
  57. package/src/server/mock.ts +1 -1
  58. package/src/server/routers/{index.ts → edge/index.ts} +3 -3
  59. package/src/server/routers/lambda/file.ts +49 -0
  60. package/src/server/routers/lambda/importer.ts +54 -0
  61. package/src/server/routers/lambda/index.ts +28 -0
  62. package/src/server/routers/lambda/message.ts +165 -0
  63. package/src/server/routers/lambda/plugin.ts +100 -0
  64. package/src/server/routers/lambda/session.ts +194 -0
  65. package/src/server/routers/lambda/sessionGroup.ts +77 -0
  66. package/src/server/routers/lambda/topic.ts +134 -0
  67. package/src/server/routers/lambda/user.ts +57 -0
  68. package/src/services/file/index.ts +4 -7
  69. package/src/services/file/server.ts +45 -0
  70. package/src/services/import/index.ts +4 -1
  71. package/src/services/import/server.ts +115 -0
  72. package/src/services/message/index.ts +4 -8
  73. package/src/services/message/server.ts +93 -0
  74. package/src/services/plugin/index.ts +4 -9
  75. package/src/services/plugin/server.ts +46 -0
  76. package/src/services/session/index.ts +4 -8
  77. package/src/services/session/server.ts +148 -0
  78. package/src/services/topic/index.ts +4 -9
  79. package/src/services/topic/server.ts +68 -0
  80. package/src/services/user/index.ts +4 -9
  81. package/src/services/user/server.ts +28 -0
  82. package/tests/setup-db.ts +7 -0
  83. package/vitest.config.ts +2 -1
  84. package/vitest.server.config.ts +23 -0
package/src/database/server/models/__tests__/topic.test.ts (new file)
@@ -0,0 +1,623 @@
1
+ import { eq, inArray } from 'drizzle-orm';
2
+ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
3
+
4
+ import { getTestDBInstance } from '@/database/server/core/dbForTest';
5
+
6
+ import { messages, sessions, topics, users } from '../../schemas/lobechat';
7
+ import { CreateTopicParams, TopicModel } from '../topic';
8
+
9
// Test database shared by every test in this file. It is never reassigned, hence `const`.
const serverDB = await getTestDBInstance();

// vitest hoists `vi.mock` calls above the rest of the module, so this factory can run
// before `serverDB` above has been initialised. Exposing it through a getter defers the
// read until the model code actually touches the database, by which time the top-level
// `await` has completed.
vi.mock('@/database/server/core/db', async () => ({
  get serverDB() {
    return serverDB;
  },
}));

// Identifiers of the user and session that `beforeEach` seeds for every test.
const userId = 'topic-user-test';
const sessionId = 'topic-session';

// Model under test, scoped to `userId`.
const topicModel = new TopicModel(userId);
20
+
21
describe('TopicModel', () => {
  beforeEach(async () => {
    // Start every test from a clean database. The suite relies on deleting `users`
    // also removing the rows that reference them (NOTE(review): confirm the FK cascades).
    await serverDB.delete(users);

    // Seed the default user and session shared by most tests.
    await serverDB.transaction(async (tx) => {
      await tx.insert(users).values({ id: userId });
      await tx.insert(sessions).values({ id: sessionId, userId });
    });
  });

  afterEach(async () => {
    // Clear the tables after each test case.
    await serverDB.delete(users);
  });

  describe('query', () => {
    it('should query topics by user ID', async () => {
      // Seed four topics for the current user and one owned by another user ('456').
      await serverDB.transaction(async (tx) => {
        await tx.insert(users).values([{ id: '456' }]);

        await tx.insert(topics).values([
          { id: '1', userId, sessionId, updatedAt: new Date('2023-01-01') },
          { id: '4', userId, sessionId, updatedAt: new Date('2023-03-01') },
          { id: '2', userId, sessionId, updatedAt: new Date('2023-02-01'), favorite: true },
          { id: '5', userId, sessionId, updatedAt: new Date('2023-05-01'), favorite: true },
          { id: '3', userId: '456', sessionId, updatedAt: new Date('2023-03-01') },
        ]);
      });

      const result = await topicModel.query({ sessionId });

      // Only the current user's topics come back: favorites first, then the rest,
      // each group sorted by updatedAt descending.
      expect(result).toHaveLength(4);
      expect(result[0].id).toBe('5');
      expect(result[1].id).toBe('2');
      expect(result[2].id).toBe('4');
      expect(result[3].id).toBe('1');
    });

    it('should query topics with pagination', async () => {
      // Seed three topics that have no sessionId.
      await serverDB.insert(topics).values([
        { id: '1', userId, updatedAt: new Date('2023-01-01') },
        { id: '2', userId, updatedAt: new Date('2023-02-01') },
        { id: '3', userId, updatedAt: new Date('2023-03-01') },
      ]);

      // First page with a page size of 2 returns two topics.
      const result1 = await topicModel.query({ current: 0, pageSize: 2 });
      expect(result1).toHaveLength(2);

      // Second page with a page size of 1 returns only the second topic.
      const result2 = await topicModel.query({ current: 1, pageSize: 1 });
      expect(result2).toHaveLength(1);
      expect(result2[0].id).toBe('2');
    });

    it('should query topics by session ID', async () => {
      // Seed two extra sessions and spread topics across them, plus one without a session.
      await serverDB.transaction(async (tx) => {
        await tx.insert(sessions).values([
          { id: 'session1', userId },
          { id: 'session2', userId },
        ]);

        await tx.insert(topics).values([
          { id: '1', userId, sessionId: 'session1' },
          { id: '2', userId, sessionId: 'session2' },
          { id: '3', userId }, // no sessionId
        ]);
      });

      // Only the topic belonging to session1 is returned.
      const result = await topicModel.query({ sessionId: 'session1' });
      expect(result).toHaveLength(1);
      expect(result[0].id).toBe('1');
    });

    it('should return topics based on pagination parameters', async () => {
      await serverDB.insert(topics).values([
        { id: 'topic1', sessionId, userId, updatedAt: new Date('2023-01-01') },
        { id: 'topic2', sessionId, userId, updatedAt: new Date('2023-01-02') },
        { id: 'topic3', sessionId, userId, updatedAt: new Date('2023-01-03') },
      ]);

      const result1 = await topicModel.query({ current: 0, pageSize: 2, sessionId });
      const result2 = await topicModel.query({ current: 1, pageSize: 2, sessionId });

      // Pages are filled newest-first by updatedAt.
      expect(result1).toHaveLength(2);
      expect(result1[0].id).toBe('topic3');
      expect(result1[1].id).toBe('topic2');

      expect(result2).toHaveLength(1);
      expect(result2[0].id).toBe('topic1');
    });
  });

  describe('findById', () => {
    it('should return a topic by id', async () => {
      await serverDB.insert(topics).values({ id: 'topic1', sessionId, userId });

      const result = await topicModel.findById('topic1');

      expect(result?.id).toBe('topic1');
    });

    it('should return undefined for non-existent topic', async () => {
      const result = await topicModel.findById('non-existent');

      expect(result).toBeUndefined();
    });
  });

  describe('queryAll', () => {
    it('should return all topics', async () => {
      await serverDB.insert(topics).values([
        { id: 'topic1', sessionId, userId },
        { id: 'topic2', sessionId, userId },
      ]);

      const result = await topicModel.queryAll();

      // Both rows come from a single INSERT, so any default timestamps are likely equal
      // and their relative order is not guaranteed; compare the ids as a set.
      expect(result).toHaveLength(2);
      expect(result.map((topic) => topic.id).sort()).toEqual(['topic1', 'topic2']);
    });
  });

  describe('queryByKeyword', () => {
    it('should return topics matching topic title keyword', async () => {
      // The keyword appears only in topic1's title, not in its message.
      await serverDB.transaction(async (tx) => {
        await tx.insert(topics).values([
          { id: 'topic1', title: 'Hello world', sessionId, userId },
          { id: 'topic2', title: 'Goodbye', sessionId, userId },
        ]);
        await tx
          .insert(messages)
          .values([
            { id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId },
          ]);
      });

      // The match is case-insensitive ('hello' vs 'Hello').
      const result = await topicModel.queryByKeyword('hello', sessionId);

      expect(result).toHaveLength(1);
      expect(result[0].id).toBe('topic1');
    });

    it('should return topics matching message content keyword', async () => {
      // The keyword appears only in a message of topic1, not in its title.
      await serverDB.transaction(async (tx) => {
        await tx.insert(topics).values([
          { id: 'topic1', title: 'abc world', sessionId, userId },
          { id: 'topic2', title: 'Goodbye', sessionId, userId },
        ]);
        await tx.insert(messages).values([
          {
            id: 'message1',
            role: 'assistant',
            content: 'Hello there',
            topicId: 'topic1',
            userId,
          },
        ]);
      });

      const result = await topicModel.queryByKeyword('hello', sessionId);

      expect(result).toHaveLength(1);
      expect(result[0].id).toBe('topic1');
    });

    it('should return nothing if not match', async () => {
      // The only topic whose title matches has no sessionId, so it is outside the searched session.
      await serverDB.insert(topics).values([
        { id: 'topic1', title: 'Hello world', userId },
        { id: 'topic2', title: 'Goodbye', sessionId, userId },
      ]);
      await serverDB
        .insert(messages)
        .values([
          { id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId },
        ]);

      const result = await topicModel.queryByKeyword('hello', sessionId);

      expect(result).toHaveLength(0);
    });
  });

  describe('count', () => {
    it('should return total number of topics', async () => {
      await serverDB.insert(topics).values([
        { id: 'abc_topic1', sessionId, userId },
        { id: 'abc_topic2', sessionId, userId },
      ]);

      const result = await topicModel.count();

      expect(result).toBe(2);
    });
  });

  describe('delete', () => {
    it('should delete a topic and its associated messages', async () => {
      const topicId = 'topic1';
      // Seed one topic with two messages for the current user, and one topic with a
      // message for another user ('345') that must survive.
      await serverDB.transaction(async (tx) => {
        await tx.insert(users).values({ id: '345' });
        await tx.insert(sessions).values([
          { id: 'session1', userId },
          { id: 'session2', userId: '345' },
        ]);
        await tx.insert(topics).values([
          { id: topicId, sessionId: 'session1', userId },
          { id: 'topic2', sessionId: 'session2', userId: '345' },
        ]);
        await tx.insert(messages).values([
          { id: 'message1', role: 'user', topicId, userId },
          { id: 'message2', role: 'assistant', topicId, userId },
          { id: 'message3', role: 'user', topicId: 'topic2', userId: '345' },
        ]);
      });

      await topicModel.delete(topicId);

      // The topic and its messages are gone; the other user's topic and message remain.
      expect(
        await serverDB.select().from(messages).where(eq(messages.topicId, topicId)),
      ).toHaveLength(0);
      expect(await serverDB.select().from(topics)).toHaveLength(1);

      expect(await serverDB.select().from(messages)).toHaveLength(1);
    });
  });

  describe('batchDeleteBySessionId', () => {
    it('should delete all topics associated with a session', async () => {
      await serverDB.insert(sessions).values([
        { id: 'session1', userId },
        { id: 'session2', userId },
      ]);
      await serverDB.insert(topics).values([
        { id: 'topic1', sessionId: 'session1', userId },
        { id: 'topic2', sessionId: 'session1', userId },
        { id: 'topic3', sessionId: 'session2', userId },
        { id: 'topic4', userId },
      ]);

      await topicModel.batchDeleteBySessionId('session1');

      // session1's topics are deleted; topic3 (session2) and topic4 (no session) remain.
      expect(
        await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')),
      ).toHaveLength(0);
      expect(await serverDB.select().from(topics)).toHaveLength(2);
    });
    it('should delete all topics associated without sessionId', async () => {
      await serverDB.insert(sessions).values([{ id: 'session1', userId }]);

      await serverDB.insert(topics).values([
        { id: 'topic1', sessionId: 'session1', userId },
        { id: 'topic2', sessionId: 'session1', userId },
        { id: 'topic4', userId },
      ]);

      // Called without a sessionId, so only topics that have no session are targeted.
      await topicModel.batchDeleteBySessionId();

      // session1's topics are kept; only topic4 (no session) was deleted.
      expect(
        await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')),
      ).toHaveLength(2);
      expect(await serverDB.select().from(topics)).toHaveLength(2);
    });
  });

  describe('batchDelete', () => {
    it('should delete multiple topics and their associated messages', async () => {
      await serverDB.transaction(async (tx) => {
        await tx.insert(sessions).values({ id: 'session1', userId });
        await tx.insert(topics).values([
          { id: 'topic1', sessionId: 'session1', userId },
          { id: 'topic2', sessionId: 'session1', userId },
          { id: 'topic3', sessionId: 'session1', userId },
        ]);
        await tx.insert(messages).values([
          { id: 'message1', role: 'user', topicId: 'topic1', userId },
          { id: 'message2', role: 'assistant', topicId: 'topic2', userId },
          { id: 'message3', role: 'user', topicId: 'topic3', userId },
        ]);
      });

      await topicModel.batchDelete(['topic1', 'topic2']);

      // The listed topics and their messages are deleted; topic3 and message3 remain.
      expect(await serverDB.select().from(topics)).toHaveLength(1);
      expect(await serverDB.select().from(messages)).toHaveLength(1);
    });
  });

  describe('deleteAll', () => {
    it('should delete all topics of the user', async () => {
      await serverDB.insert(users).values({ id: '345' });
      await serverDB.insert(sessions).values([
        { id: 'session1', userId },
        { id: 'session2', userId: '345' },
      ]);
      await serverDB.insert(topics).values([
        { id: 'topic1', sessionId: 'session1', userId },
        { id: 'topic2', sessionId: 'session1', userId },
        { id: 'topic3', sessionId: 'session2', userId: '345' },
      ]);

      await topicModel.deleteAll();

      // Every topic of the current user is deleted; the other user's topic remains.
      expect(await serverDB.select().from(topics).where(eq(topics.userId, userId))).toHaveLength(0);
      expect(await serverDB.select().from(topics)).toHaveLength(1);
    });
  });

  describe('update', () => {
    it('should update a topic', async () => {
      // Create a test topic owned by the current user.
      const topicId = '123';
      await serverDB.insert(topics).values({ userId, id: topicId, title: 'Test', favorite: true });

      // Update the topic through the model.
      const item = await topicModel.update(topicId, {
        title: 'Updated Test',
        favorite: false,
      });

      // The updated row is returned with the new values.
      expect(item).toHaveLength(1);
      expect(item[0].title).toBe('Updated Test');
      expect(item[0].favorite).toBe(false);
    });

    it('should not update a topic if user ID does not match', async () => {
      // Create a test topic owned by a different user.
      await serverDB.insert(users).values([{ id: '456' }]);
      const topicId = '123';
      await serverDB
        .insert(topics)
        .values({ userId: '456', id: topicId, title: 'Test', favorite: true });

      // Updating another user's topic must be a no-op.
      const item = await topicModel.update(topicId, {
        title: 'Updated Test Session',
      });

      expect(item).toHaveLength(0);

      // The stored row is untouched.
      const [stored] = await serverDB.select().from(topics).where(eq(topics.id, topicId));
      expect(stored.title).toBe('Test');
    });
  });

  describe('create', () => {
    it('should create a new topic and associate messages', async () => {
      const topicData = {
        title: 'New Topic',
        favorite: true,
        sessionId,
        messages: ['message1', 'message2'],
      } satisfies CreateTopicParams;

      const topicId = 'new-topic';

      // Pre-create messages; only message1 and message2 should be linked to the new topic.
      await serverDB.insert(messages).values([
        { id: 'message1', role: 'user', userId, sessionId },
        { id: 'message2', role: 'assistant', userId, sessionId },
        { id: 'message3', role: 'user', userId, sessionId },
      ]);

      const createdTopic = await topicModel.create(topicData, topicId);

      // The returned topic carries the requested fields.
      expect(createdTopic).toEqual({
        id: topicId,
        title: 'New Topic',
        favorite: true,
        sessionId,
        userId,
        clientId: null,
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
      });

      // The topic is persisted exactly as returned.
      const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
      expect(dbTopic).toHaveLength(1);
      expect(dbTopic[0]).toEqual(createdTopic);

      // The listed messages now point at the new topic.
      const associatedMessages = await serverDB
        .select()
        .from(messages)
        .where(inArray(messages.id, topicData.messages));
      expect(associatedMessages).toHaveLength(2);
      expect(associatedMessages.every((msg) => msg.topicId === topicId)).toBe(true);

      // The unlisted message keeps a null topicId.
      const unassociatedMessage = await serverDB
        .select()
        .from(messages)
        .where(eq(messages.id, 'message3'));

      expect(unassociatedMessage[0].topicId).toBeNull();
    });

    it('should create a new topic without associating messages', async () => {
      const topicData = {
        title: 'New Topic',
        favorite: false,
        sessionId,
      };

      const topicId = 'new-topic';

      const createdTopic = await topicModel.create(topicData, topicId);

      // The returned topic carries the requested fields.
      expect(createdTopic).toEqual({
        id: topicId,
        title: 'New Topic',
        favorite: false,
        clientId: null,
        sessionId,
        userId,
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
      });

      // The topic is persisted exactly as returned.
      const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
      expect(dbTopic).toHaveLength(1);
      expect(dbTopic[0]).toEqual(createdTopic);
    });
  });

  describe('batchCreate', () => {
    it('should batch create topics and update associated messages', async () => {
      const topicParams = [
        {
          title: 'Topic 1',
          favorite: true,
          sessionId,
          messages: ['message1', 'message2'],
        },
        {
          title: 'Topic 2',
          favorite: false,
          sessionId,
          messages: ['message3'],
        },
      ];
      await serverDB.insert(messages).values([
        { id: 'message1', role: 'user', userId },
        { id: 'message2', role: 'assistant', userId },
        { id: 'message3', role: 'user', userId },
      ]);

      const createdTopics = await topicModel.batchCreate(topicParams);

      // The returned topics mirror the input params, in input order.
      expect(createdTopics).toHaveLength(2);
      expect(createdTopics[0]).toMatchObject({
        title: 'Topic 1',
        favorite: true,
        sessionId,
        userId,
      });
      expect(createdTopics[1]).toMatchObject({
        title: 'Topic 2',
        favorite: false,
        sessionId,
        userId,
      });

      // The rows below are read without ORDER BY, and PostgreSQL does not guarantee row
      // order (the UPDATE on messages also rewrites their tuples), so look rows up by a
      // stable key instead of by position.
      const items = await serverDB.select().from(topics);
      expect(items).toHaveLength(2);
      expect(items.find((item) => item.title === 'Topic 1')).toMatchObject({
        title: 'Topic 1',
        favorite: true,
        sessionId,
        userId,
      });
      expect(items.find((item) => item.title === 'Topic 2')).toMatchObject({
        title: 'Topic 2',
        favorite: false,
        sessionId,
        userId,
      });

      // Each message now points at the topic whose params listed it.
      const updatedMessages = await serverDB.select().from(messages);
      expect(updatedMessages).toHaveLength(3);
      const topicIdOf = (id: string) => updatedMessages.find((msg) => msg.id === id)?.topicId;
      expect(topicIdOf('message1')).toBe(createdTopics[0].id);
      expect(topicIdOf('message2')).toBe(createdTopics[0].id);
      expect(topicIdOf('message3')).toBe(createdTopics[1].id);
    });

    it('should generate topic IDs if not provided', async () => {
      const topicParams = [
        {
          title: 'Topic 1',
          favorite: true,
          sessionId,
        },
        {
          title: 'Topic 2',
          favorite: false,
          sessionId,
        },
      ];

      const createdTopics = await topicModel.batchCreate(topicParams);

      // Each topic received its own generated id.
      expect(createdTopics[0].id).toBeDefined();
      expect(createdTopics[1].id).toBeDefined();
      expect(createdTopics[0].id).not.toBe(createdTopics[1].id);
    });
  });

  describe('duplicate', () => {
    it('should duplicate a topic and its associated messages', async () => {
      const topicId = 'topic-duplicate';
      const newTitle = 'Duplicated Topic';

      // Create the original topic and its messages.
      await serverDB.transaction(async (tx) => {
        await tx.insert(topics).values({ id: topicId, sessionId, userId, title: 'Original Topic' });
        await tx.insert(messages).values([
          { id: 'message1', role: 'user', topicId, userId, content: 'User message' },
          { id: 'message2', role: 'assistant', topicId, userId, content: 'Assistant message' },
        ]);
      });

      const { topic: duplicatedTopic, messages: duplicatedMessages } = await topicModel.duplicate(
        topicId,
        newTitle,
      );

      // The copy is a new topic with the new title, in the same session, for the same user.
      expect(duplicatedTopic.id).not.toBe(topicId);
      expect(duplicatedTopic.title).toBe(newTitle);
      expect(duplicatedTopic.sessionId).toBe(sessionId);
      expect(duplicatedTopic.userId).toBe(userId);

      // Both messages are copied under new ids and attached to the new topic. The originals
      // were inserted in one statement, so match copies by content rather than by position.
      expect(duplicatedMessages).toHaveLength(2);
      const userCopy = duplicatedMessages.find((msg) => msg.content === 'User message');
      const assistantCopy = duplicatedMessages.find((msg) => msg.content === 'Assistant message');

      expect(userCopy).toMatchObject({ topicId: duplicatedTopic.id, content: 'User message' });
      expect(userCopy?.id).not.toBe('message1');
      expect(assistantCopy).toMatchObject({
        topicId: duplicatedTopic.id,
        content: 'Assistant message',
      });
      expect(assistantCopy?.id).not.toBe('message2');
    });

    it('should throw an error if the topic to duplicate does not exist', async () => {
      const topicId = 'nonexistent-topic';

      // Duplicating a missing topic is expected to reject.
      await expect(topicModel.duplicate(topicId)).rejects.toThrow(
        `Topic with id ${topicId} not found`,
      );
    });
  });
});