@lobehub/chat 1.71.5 → 1.72.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/CHANGELOG.md +25 -0
  2. package/changelog/v1.json +9 -0
  3. package/docs/developer/database-schema.dbml +16 -0
  4. package/package.json +3 -3
  5. package/src/database/client/db.ts +14 -8
  6. package/src/database/client/migrations.json +62 -0
  7. package/src/database/migrations/0017_add_user_id_to_tables.sql +225 -0
  8. package/src/database/migrations/meta/0017_snapshot.json +3858 -0
  9. package/src/database/migrations/meta/_journal.json +7 -0
  10. package/src/database/{server/models → models}/__tests__/_test_template.ts +2 -2
  11. package/src/database/models/__tests__/_util.ts +12 -0
  12. package/src/database/{server/models → models}/__tests__/agent.test.ts +6 -5
  13. package/src/database/{server/models → models}/__tests__/aiModel.test.ts +5 -4
  14. package/src/database/{server/models → models}/__tests__/aiProvider.test.ts +5 -4
  15. package/src/database/{server/models → models}/__tests__/asyncTask.test.ts +5 -4
  16. package/src/database/{server/models → models}/__tests__/chunk.test.ts +25 -21
  17. package/src/database/{server/models → models}/__tests__/file.test.ts +19 -5
  18. package/src/database/{server/models → models}/__tests__/knowledgeBase.test.ts +9 -4
  19. package/src/database/{server/models → models}/__tests__/message.test.ts +625 -29
  20. package/src/database/{server/models → models}/__tests__/plugin.test.ts +5 -4
  21. package/src/database/{server/models → models}/__tests__/session.test.ts +23 -20
  22. package/src/database/{server/models → models}/__tests__/sessionGroup.test.ts +5 -4
  23. package/src/database/{server/models → models}/__tests__/topic.test.ts +5 -4
  24. package/src/database/repositories/dataImporter/index.ts +3 -0
  25. package/src/database/schemas/file.ts +38 -32
  26. package/src/database/schemas/message.ts +21 -0
  27. package/src/database/schemas/relations.ts +10 -0
  28. package/src/database/server/models/__tests__/nextauth.test.ts +2 -0
  29. package/src/database/server/models/__tests__/user.test.ts +13 -1
  30. package/src/database/server/models/chunk.ts +5 -1
  31. package/src/database/server/models/file.ts +6 -3
  32. package/src/database/server/models/message.ts +29 -12
  33. package/src/database/server/models/session.ts +1 -0
  34. package/src/services/file/client.test.ts +2 -1
  35. package/src/services/message/client.test.ts +3 -3
  36. package/src/services/session/client.test.ts +5 -3
  37. package/src/types/message/base.ts +7 -0
  38. package/vitest.server.config.ts +1 -1
  39. package/src/database/server/models/user.test.ts +0 -58
  40. /package/src/database/{server/models → models}/__tests__/fixtures/embedding.ts +0 -0
@@ -2,7 +2,8 @@ import dayjs from 'dayjs';
2
2
  import { eq } from 'drizzle-orm/expressions';
3
3
  import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
4
4
 
5
- import { getTestDBInstance } from '@/database/server/core/dbForTest';
5
+ import { getTestDB } from '@/database/models/__tests__/_util';
6
+ import { LobeChatDatabase } from '@/database/type';
6
7
  import { MessageItem } from '@/types/message';
7
8
  import { uuid } from '@/utils/uuid';
8
9
 
@@ -21,15 +22,15 @@ import {
21
22
  sessions,
22
23
  topics,
23
24
  users,
24
- } from '../../../schemas';
25
- import { MessageModel } from '../message';
25
+ } from '../../schemas';
26
+ import { MessageModel } from '../../server/models/message';
26
27
  import { codeEmbedding } from './fixtures/embedding';
27
28
 
28
- let serverDB = await getTestDBInstance();
29
+ const serverDB: LobeChatDatabase = await getTestDB();
29
30
 
30
31
  const userId = 'message-db';
31
32
  const messageModel = new MessageModel(serverDB, userId);
32
-
33
+ const embeddingsId = uuid();
33
34
  beforeEach(async () => {
34
35
  // 在每个测试用例之前,清空表
35
36
  await serverDB.transaction(async (trx) => {
@@ -49,6 +50,12 @@ beforeEach(async () => {
49
50
  fileType: 'image/png',
50
51
  size: 1000,
51
52
  });
53
+
54
+ await trx.insert(embeddings).values({
55
+ id: embeddingsId,
56
+ embeddings: codeEmbedding,
57
+ userId,
58
+ });
52
59
  });
53
60
  });
54
61
 
@@ -201,13 +208,14 @@ describe('MessageModel', () => {
201
208
  { id: 'f-1', url: 'abc', name: 'file-1', userId, fileType: 'image/png', size: 100 },
202
209
  { id: 'f-3', url: 'abc', name: 'file-3', userId, fileType: 'image/png', size: 400 },
203
210
  ]);
204
- await trx
205
- .insert(messageTTS)
206
- .values([{ id: '1' }, { id: '2', voice: 'a', fileId: 'f-1', contentMd5: 'abc' }]);
211
+ await trx.insert(messageTTS).values([
212
+ { id: '1', userId },
213
+ { id: '2', voice: 'a', fileId: 'f-1', contentMd5: 'abc', userId },
214
+ ]);
207
215
 
208
216
  await trx.insert(messagesFiles).values([
209
- { fileId: 'f-0', messageId: '1' },
210
- { fileId: 'f-3', messageId: '1' },
217
+ { fileId: 'f-0', messageId: '1', userId },
218
+ { fileId: 'f-3', messageId: '1', userId },
211
219
  ]);
212
220
  });
213
221
 
@@ -244,10 +252,10 @@ describe('MessageModel', () => {
244
252
  ]);
245
253
  await trx
246
254
  .insert(messageTranslates)
247
- .values([{ id: '1', content: 'translated', from: 'en', to: 'zh' }]);
255
+ .values([{ id: '1', content: 'translated', from: 'en', to: 'zh', userId }]);
248
256
  await trx
249
257
  .insert(messageTTS)
250
- .values([{ id: '1', voice: 'voice1', fileId: 'f1', contentMd5: 'md5' }]);
258
+ .values([{ id: '1', voice: 'voice1', fileId: 'f1', contentMd5: 'md5', userId }]);
251
259
  });
252
260
 
253
261
  // 调用 query 方法
@@ -281,7 +289,81 @@ describe('MessageModel', () => {
281
289
  expect(result3).toHaveLength(0);
282
290
  });
283
291
 
284
- // 补充测试复杂查询场景
292
+ describe('query with messageQueries', () => {
293
+ it('should include ragQuery, ragQueryId and ragRawQuery in query results', async () => {
294
+ // 创建测试数据
295
+ const messageId = 'msg-with-query';
296
+ const queryId = uuid();
297
+
298
+ await serverDB.insert(messages).values({
299
+ id: messageId,
300
+ userId,
301
+ role: 'user',
302
+ content: 'test message',
303
+ });
304
+
305
+ await serverDB.insert(messageQueries).values({
306
+ id: queryId,
307
+ messageId,
308
+ userQuery: 'original query',
309
+ rewriteQuery: 'rewritten query',
310
+ userId,
311
+ });
312
+
313
+ // 调用 query 方法
314
+ const result = await messageModel.query();
315
+
316
+ // 断言结果
317
+ expect(result).toHaveLength(1);
318
+ expect(result[0].id).toBe(messageId);
319
+ expect(result[0].ragQueryId).toBe(queryId);
320
+ expect(result[0].ragQuery).toBe('rewritten query');
321
+ expect(result[0].ragRawQuery).toBe('original query');
322
+ });
323
+
324
+ it.skip('should handle multiple message queries for the same message', async () => {
325
+ // 创建测试数据
326
+ const messageId = 'msg-multi-query';
327
+ const queryId1 = uuid();
328
+ const queryId2 = uuid();
329
+
330
+ await serverDB.insert(messages).values({
331
+ id: messageId,
332
+ userId,
333
+ role: 'user',
334
+ content: 'test message',
335
+ });
336
+
337
+ // 创建两个查询,但查询结果应该只包含一个(最新的)
338
+ await serverDB.insert(messageQueries).values([
339
+ {
340
+ id: queryId1,
341
+ messageId,
342
+ userQuery: 'original query 1',
343
+ rewriteQuery: 'rewritten query 1',
344
+ userId,
345
+ },
346
+ {
347
+ id: queryId2,
348
+ messageId,
349
+ userQuery: 'original query 2',
350
+ rewriteQuery: 'rewritten query 2',
351
+ userId,
352
+ },
353
+ ]);
354
+
355
+ // 调用 query 方法
356
+ const result = await messageModel.query();
357
+
358
+ // 断言结果 - 应该只包含最新的查询
359
+ expect(result).toHaveLength(1);
360
+ expect(result[0].id).toBe(messageId);
361
+ expect(result[0].ragQueryId).toBe(queryId2);
362
+ expect(result[0].ragQuery).toBe('rewritten query 2');
363
+ expect(result[0].ragRawQuery).toBe('original query 2');
364
+ });
365
+ });
366
+
285
367
  it('should handle complex query with multiple joins and file chunks', async () => {
286
368
  await serverDB.transaction(async (trx) => {
287
369
  const chunk1Id = uuid();
@@ -316,12 +398,14 @@ describe('MessageModel', () => {
316
398
  // 关联消息和文件
317
399
  await trx.insert(messagesFiles).values({
318
400
  messageId: 'msg1',
401
+ userId,
319
402
  fileId: 'file1',
320
403
  });
321
404
 
322
405
  // 创建文件块关联
323
406
  await trx.insert(fileChunks).values({
324
407
  fileId: 'file1',
408
+ userId,
325
409
  chunkId: chunk1Id,
326
410
  });
327
411
 
@@ -329,6 +413,7 @@ describe('MessageModel', () => {
329
413
  await trx.insert(messageQueries).values({
330
414
  id: query1Id,
331
415
  messageId: 'msg1',
416
+ userId,
332
417
  userQuery: 'original query',
333
418
  rewriteQuery: 'rewritten query',
334
419
  });
@@ -339,6 +424,7 @@ describe('MessageModel', () => {
339
424
  queryId: query1Id,
340
425
  chunkId: chunk1Id,
341
426
  similarity: '0.95',
427
+ userId,
342
428
  });
343
429
  });
344
430
 
@@ -648,6 +734,135 @@ describe('MessageModel', () => {
648
734
  expect(pluginResult[0].identifier).toBe('lobe-web-browsing');
649
735
  expect(pluginResult[0].state!).toMatchObject(state);
650
736
  });
737
+
738
+ describe('create with advanced parameters', () => {
739
+ it('should create a message with custom ID', async () => {
740
+ const customId = 'custom-msg-id';
741
+
742
+ const result = await messageModel.create(
743
+ {
744
+ role: 'user',
745
+ content: 'message with custom ID',
746
+ sessionId: '1',
747
+ },
748
+ customId,
749
+ );
750
+
751
+ expect(result.id).toBe(customId);
752
+
753
+ // 验证数据库中的记录
754
+ const dbResult = await serverDB.select().from(messages).where(eq(messages.id, customId));
755
+ expect(dbResult).toHaveLength(1);
756
+ expect(dbResult[0].id).toBe(customId);
757
+ });
758
+
759
+ it.skip('should create a message with file chunks and RAG query ID', async () => {
760
+ // 创建测试数据
761
+ const chunkId1 = uuid();
762
+ const chunkId2 = uuid();
763
+ const ragQueryId = uuid();
764
+
765
+ await serverDB.insert(chunks).values([
766
+ { id: chunkId1, text: 'chunk text 1' },
767
+ { id: chunkId2, text: 'chunk text 2' },
768
+ ]);
769
+
770
+ // 调用 create 方法
771
+ const result = await messageModel.create({
772
+ role: 'assistant',
773
+ content: 'message with file chunks',
774
+ fileChunks: [
775
+ { id: chunkId1, similarity: 0.95 },
776
+ { id: chunkId2, similarity: 0.85 },
777
+ ],
778
+ ragQueryId,
779
+ sessionId: '1',
780
+ });
781
+
782
+ // 验证消息创建成功
783
+ expect(result.id).toBeDefined();
784
+
785
+ // 验证消息查询块关联创建成功
786
+ const queryChunks = await serverDB
787
+ .select()
788
+ .from(messageQueryChunks)
789
+ .where(eq(messageQueryChunks.messageId, result.id));
790
+
791
+ expect(queryChunks).toHaveLength(2);
792
+ expect(queryChunks[0].chunkId).toBe(chunkId1);
793
+ expect(queryChunks[0].queryId).toBe(ragQueryId);
794
+ expect(queryChunks[0].similarity).toBe('0.95');
795
+ expect(queryChunks[1].chunkId).toBe(chunkId2);
796
+ expect(queryChunks[1].similarity).toBe('0.85');
797
+ });
798
+
799
+ it('should create a message with files', async () => {
800
+ // 创建测试数据
801
+ await serverDB.insert(files).values([
802
+ {
803
+ id: 'file1',
804
+ name: 'file1.txt',
805
+ fileType: 'text/plain',
806
+ size: 100,
807
+ url: 'url1',
808
+ userId,
809
+ },
810
+ {
811
+ id: 'file2',
812
+ name: 'file2.jpg',
813
+ fileType: 'image/jpeg',
814
+ size: 200,
815
+ url: 'url2',
816
+ userId,
817
+ },
818
+ ]);
819
+
820
+ // 调用 create 方法
821
+ const result = await messageModel.create({
822
+ role: 'user',
823
+ content: 'message with files',
824
+ files: ['file1', 'file2'],
825
+ sessionId: '1',
826
+ });
827
+
828
+ // 验证消息创建成功
829
+ expect(result.id).toBeDefined();
830
+
831
+ // 验证消息文件关联创建成功
832
+ const messageFiles = await serverDB
833
+ .select()
834
+ .from(messagesFiles)
835
+ .where(eq(messagesFiles.messageId, result.id));
836
+
837
+ expect(messageFiles).toHaveLength(2);
838
+ expect(messageFiles[0].fileId).toBe('file1');
839
+ expect(messageFiles[1].fileId).toBe('file2');
840
+ });
841
+
842
+ it('should create a message with custom timestamps', async () => {
843
+ const customCreatedAt = '2022-05-15T10:30:00Z';
844
+ const customUpdatedAt = '2022-05-16T11:45:00Z';
845
+
846
+ const result = await messageModel.create({
847
+ role: 'user',
848
+ content: 'message with custom timestamps',
849
+ createdAt: customCreatedAt as any,
850
+ updatedAt: customUpdatedAt as any,
851
+ sessionId: '1',
852
+ });
853
+
854
+ // 验证数据库中的记录
855
+ const dbResult = await serverDB.select().from(messages).where(eq(messages.id, result.id));
856
+
857
+ // 日期比较需要考虑时区和格式化问题,所以使用 toISOString 进行比较
858
+ expect(new Date(dbResult[0].createdAt!).toISOString()).toBe(
859
+ new Date(customCreatedAt).toISOString(),
860
+ );
861
+ expect(new Date(dbResult[0].updatedAt!).toISOString()).toBe(
862
+ new Date(customUpdatedAt).toISOString(),
863
+ );
864
+ });
865
+ });
651
866
  });
652
867
 
653
868
  describe('batchCreateMessages', () => {
@@ -734,10 +949,122 @@ describe('MessageModel', () => {
734
949
 
735
950
  // 断言结果
736
951
  const result = await serverDB.select().from(messages).where(eq(messages.id, '1'));
737
- expect(result[0].tools[0].arguments).toBe(
952
+ expect((result[0].tools as any)[0].arguments).toBe(
738
953
  '{"query":"2024 杭州暴雨","searchEngines":["duckduckgo","google","brave"]}',
739
954
  );
740
955
  });
956
+
957
+ describe('update with imageList', () => {
958
+ it('should update a message and add image files', async () => {
959
+ // 创建测试数据
960
+ await serverDB.insert(messages).values({
961
+ id: 'msg-to-update',
962
+ userId,
963
+ role: 'user',
964
+ content: 'original content',
965
+ });
966
+
967
+ await serverDB.insert(files).values([
968
+ {
969
+ id: 'img1',
970
+ name: 'image1.jpg',
971
+ fileType: 'image/jpeg',
972
+ size: 100,
973
+ url: 'url1',
974
+ userId,
975
+ },
976
+ { id: 'img2', name: 'image2.png', fileType: 'image/png', size: 200, url: 'url2', userId },
977
+ ]);
978
+
979
+ // 调用 update 方法
980
+ await messageModel.update('msg-to-update', {
981
+ content: 'updated content',
982
+ imageList: [
983
+ { id: 'img1', alt: 'image 1', url: 'url1' },
984
+ { id: 'img2', alt: 'image 2', url: 'url2' },
985
+ ],
986
+ });
987
+
988
+ // 验证消息更新成功
989
+ const updatedMessage = await serverDB
990
+ .select()
991
+ .from(messages)
992
+ .where(eq(messages.id, 'msg-to-update'));
993
+
994
+ expect(updatedMessage[0].content).toBe('updated content');
995
+
996
+ // 验证消息文件关联创建成功
997
+ const messageFiles = await serverDB
998
+ .select()
999
+ .from(messagesFiles)
1000
+ .where(eq(messagesFiles.messageId, 'msg-to-update'));
1001
+
1002
+ expect(messageFiles).toHaveLength(2);
1003
+ expect(messageFiles[0].fileId).toBe('img1');
1004
+ expect(messageFiles[1].fileId).toBe('img2');
1005
+ });
1006
+
1007
+ it('should handle empty imageList', async () => {
1008
+ // 创建测试数据
1009
+ await serverDB.insert(messages).values({
1010
+ id: 'msg-no-images',
1011
+ userId,
1012
+ role: 'user',
1013
+ content: 'original content',
1014
+ });
1015
+
1016
+ // 调用 update 方法,不提供 imageList
1017
+ await messageModel.update('msg-no-images', {
1018
+ content: 'updated content',
1019
+ });
1020
+
1021
+ // 验证消息更新成功
1022
+ const updatedMessage = await serverDB
1023
+ .select()
1024
+ .from(messages)
1025
+ .where(eq(messages.id, 'msg-no-images'));
1026
+
1027
+ expect(updatedMessage[0].content).toBe('updated content');
1028
+
1029
+ // 验证没有创建消息文件关联
1030
+ const messageFiles = await serverDB
1031
+ .select()
1032
+ .from(messagesFiles)
1033
+ .where(eq(messagesFiles.messageId, 'msg-no-images'));
1034
+
1035
+ expect(messageFiles).toHaveLength(0);
1036
+ });
1037
+
1038
+ it('should update multiple fields at once', async () => {
1039
+ // 创建测试数据
1040
+ await serverDB.insert(messages).values({
1041
+ id: 'msg-multi-update',
1042
+ userId,
1043
+ role: 'user',
1044
+ content: 'original content',
1045
+ model: 'gpt-3.5',
1046
+ });
1047
+
1048
+ // 调用 update 方法,更新多个字段
1049
+ await messageModel.update('msg-multi-update', {
1050
+ content: 'updated content',
1051
+ role: 'assistant',
1052
+ model: 'gpt-4',
1053
+ metadata: { tps: 1 },
1054
+ });
1055
+
1056
+ // 验证消息更新成功
1057
+ const updatedMessage = await serverDB
1058
+ .select()
1059
+ .from(messages)
1060
+ .where(eq(messages.id, 'msg-multi-update'));
1061
+
1062
+ expect(updatedMessage[0].content).toBe('updated content');
1063
+ expect(updatedMessage[0].role).toBe('assistant');
1064
+ expect(updatedMessage[0].model).toBe('gpt-4');
1065
+ expect(updatedMessage[0].metadata).toEqual({ tps: 1 });
1066
+ });
1067
+ });
741
1068
  });
742
1069
 
743
1070
  describe('deleteMessage', () => {
@@ -764,7 +1091,7 @@ describe('MessageModel', () => {
764
1091
  ]);
765
1092
  await trx
766
1093
  .insert(messagePlugins)
767
- .values([{ id: '2', toolCallId: 'tool1', identifier: 'plugin-1' }]);
1094
+ .values([{ id: '2', toolCallId: 'tool1', identifier: 'plugin-1', userId }]);
768
1095
  });
769
1096
 
770
1097
  // 调用 deleteMessage 方法
@@ -858,11 +1185,15 @@ describe('MessageModel', () => {
858
1185
  it('should update the state field in messagePlugins table', async () => {
859
1186
  // 创建测试数据
860
1187
  await serverDB.insert(messages).values({ id: '1', content: 'abc', role: 'user', userId });
861
- await serverDB
862
- .insert(messagePlugins)
863
- .values([
864
- { id: '1', toolCallId: 'tool1', identifier: 'plugin1', state: { key1: 'value1' } },
865
- ]);
1188
+ await serverDB.insert(messagePlugins).values([
1189
+ {
1190
+ id: '1',
1191
+ toolCallId: 'tool1',
1192
+ identifier: 'plugin1',
1193
+ state: { key1: 'value1' },
1194
+ userId,
1195
+ },
1196
+ ]);
866
1197
 
867
1198
  // 调用 updatePluginState 方法
868
1199
  await messageModel.updatePluginState('1', { key2: 'value2' });
@@ -884,11 +1215,15 @@ describe('MessageModel', () => {
884
1215
  it('should update the state field in messagePlugins table', async () => {
885
1216
  // 创建测试数据
886
1217
  await serverDB.insert(messages).values({ id: '1', content: 'abc', role: 'user', userId });
887
- await serverDB
888
- .insert(messagePlugins)
889
- .values([
890
- { id: '1', toolCallId: 'tool1', identifier: 'plugin1', state: { key1: 'value1' } },
891
- ]);
1218
+ await serverDB.insert(messagePlugins).values([
1219
+ {
1220
+ id: '1',
1221
+ toolCallId: 'tool1',
1222
+ identifier: 'plugin1',
1223
+ state: { key1: 'value1' },
1224
+ userId,
1225
+ },
1226
+ ]);
892
1227
 
893
1228
  // 调用 updatePluginState 方法
894
1229
  await messageModel.updateMessagePlugin('1', { identifier: 'plugin2' });
@@ -939,7 +1274,7 @@ describe('MessageModel', () => {
939
1274
  .values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
940
1275
  await trx
941
1276
  .insert(messageTranslates)
942
- .values([{ id: '1', content: 'translated message 1', from: 'en', to: 'zh' }]);
1277
+ .values([{ id: '1', content: 'translated message 1', from: 'en', to: 'zh', userId }]);
943
1278
  });
944
1279
 
945
1280
  // 调用 updateTranslate 方法
@@ -980,7 +1315,7 @@ describe('MessageModel', () => {
980
1315
  .values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
981
1316
  await trx
982
1317
  .insert(messageTTS)
983
- .values([{ id: '1', contentMd5: 'md5', fileId: 'f1', voice: 'voice1' }]);
1318
+ .values([{ id: '1', contentMd5: 'md5', fileId: 'f1', voice: 'voice1', userId }]);
984
1319
  });
985
1320
 
986
1321
  // 调用 updateTTS 方法
@@ -997,7 +1332,7 @@ describe('MessageModel', () => {
997
1332
  it('should delete the message translate record', async () => {
998
1333
  // 创建测试数据
999
1334
  await serverDB.insert(messages).values([{ id: '1', role: 'abc', userId }]);
1000
- await serverDB.insert(messageTranslates).values([{ id: '1' }]);
1335
+ await serverDB.insert(messageTranslates).values([{ id: '1', userId }]);
1001
1336
 
1002
1337
  // 调用 deleteMessageTranslate 方法
1003
1338
  await messageModel.deleteMessageTranslate('1');
@@ -1016,7 +1351,7 @@ describe('MessageModel', () => {
1016
1351
  it('should delete the message TTS record', async () => {
1017
1352
  // 创建测试数据
1018
1353
  await serverDB.insert(messages).values([{ id: '1', role: 'abc', userId }]);
1019
- await serverDB.insert(messageTTS).values([{ id: '1' }]);
1354
+ await serverDB.insert(messageTTS).values([{ userId, id: '1' }]);
1020
1355
 
1021
1356
  // 调用 deleteMessageTTS 方法
1022
1357
  await messageModel.deleteMessageTTS('1');
@@ -1042,6 +1377,90 @@ describe('MessageModel', () => {
1042
1377
  // 断言结果
1043
1378
  expect(result).toBe(2);
1044
1379
  });
1380
+
1381
+ describe('count with date filters', () => {
1382
+ beforeEach(async () => {
1383
+ // 创建测试数据,包含不同日期的消息
1384
+ await serverDB.insert(messages).values([
1385
+ {
1386
+ id: 'date1',
1387
+ userId,
1388
+ role: 'user',
1389
+ content: 'message 1',
1390
+ createdAt: new Date('2023-01-15'),
1391
+ },
1392
+ {
1393
+ id: 'date2',
1394
+ userId,
1395
+ role: 'user',
1396
+ content: 'message 2',
1397
+ createdAt: new Date('2023-02-15'),
1398
+ },
1399
+ {
1400
+ id: 'date3',
1401
+ userId,
1402
+ role: 'user',
1403
+ content: 'message 3',
1404
+ createdAt: new Date('2023-03-15'),
1405
+ },
1406
+ {
1407
+ id: 'date4',
1408
+ userId,
1409
+ role: 'user',
1410
+ content: 'message 4',
1411
+ createdAt: new Date('2023-04-15'),
1412
+ },
1413
+ ]);
1414
+ });
1415
+
1416
+ it('should count messages with startDate filter', async () => {
1417
+ const result = await messageModel.count({ startDate: '2023-02-01' });
1418
+ expect(result).toBe(3); // 2月15日, 3月15日, 4月15日的消息
1419
+ });
1420
+
1421
+ it('should count messages with endDate filter', async () => {
1422
+ const result = await messageModel.count({ endDate: '2023-03-01' });
1423
+ expect(result).toBe(2); // 1月15日, 2月15日的消息
1424
+ });
1425
+
1426
+ it('should count messages with both startDate and endDate filters', async () => {
1427
+ const result = await messageModel.count({
1428
+ startDate: '2023-02-01',
1429
+ endDate: '2023-03-31',
1430
+ });
1431
+ expect(result).toBe(2); // 2月15日, 3月15日的消息
1432
+ });
1433
+
1434
+ it('should count messages with range filter', async () => {
1435
+ const result = await messageModel.count({
1436
+ range: ['2023-02-01', '2023-04-01'],
1437
+ });
1438
+ expect(result).toBe(2); // 2月15日, 3月15日的消息
1439
+ });
1440
+
1441
+ it('should handle edge cases in date filters', async () => {
1442
+ // 边界日期
1443
+ const result1 = await messageModel.count({
1444
+ startDate: '2023-01-15',
1445
+ endDate: '2023-04-15',
1446
+ });
1447
+ expect(result1).toBe(4); // 包含所有消息
1448
+
1449
+ // 没有消息的日期范围
1450
+ const result2 = await messageModel.count({
1451
+ startDate: '2023-05-01',
1452
+ endDate: '2023-06-01',
1453
+ });
1454
+ expect(result2).toBe(0);
1455
+
1456
+ // 精确到一天
1457
+ const result3 = await messageModel.count({
1458
+ startDate: '2023-01-15',
1459
+ endDate: '2023-01-15',
1460
+ });
1461
+ expect(result3).toBe(1);
1462
+ });
1463
+ });
1045
1464
  });
1046
1465
 
1047
1466
  describe('findMessageQueriesById', () => {
@@ -1068,6 +1487,7 @@ describe('MessageModel', () => {
1068
1487
  userQuery: 'test query',
1069
1488
  rewriteQuery: 'rewritten query',
1070
1489
  embeddingsId: embeddings1Id,
1490
+ userId,
1071
1491
  });
1072
1492
  });
1073
1493
 
@@ -1523,4 +1943,180 @@ describe('MessageModel', () => {
1523
1943
  expect(result3).toBe(true);
1524
1944
  });
1525
1945
  });
1946
+
1947
+ describe('createMessageQuery', () => {
1948
+ it('should create a new message query', async () => {
1949
+ // 创建测试数据
1950
+ await serverDB.insert(messages).values({
1951
+ id: 'msg1',
1952
+ userId,
1953
+ role: 'user',
1954
+ content: 'test message',
1955
+ });
1956
+
1957
+ // 调用 createMessageQuery 方法
1958
+ const result = await messageModel.createMessageQuery({
1959
+ messageId: 'msg1',
1960
+ userQuery: 'original query',
1961
+ rewriteQuery: 'rewritten query',
1962
+ embeddingsId,
1963
+ });
1964
+
1965
+ // 断言结果
1966
+ expect(result).toBeDefined();
1967
+ expect(result.id).toBeDefined();
1968
+ expect(result.messageId).toBe('msg1');
1969
+ expect(result.userQuery).toBe('original query');
1970
+ expect(result.rewriteQuery).toBe('rewritten query');
1971
+ expect(result.userId).toBe(userId);
1972
+
1973
+ // 验证数据库中的记录
1974
+ const dbResult = await serverDB
1975
+ .select()
1976
+ .from(messageQueries)
1977
+ .where(eq(messageQueries.id, result.id));
1978
+
1979
+ expect(dbResult).toHaveLength(1);
1980
+ expect(dbResult[0].messageId).toBe('msg1');
1981
+ expect(dbResult[0].userQuery).toBe('original query');
1982
+ expect(dbResult[0].rewriteQuery).toBe('rewritten query');
1983
+ });
1984
+
1985
+ it('should create a message query with embeddings ID', async () => {
1986
+ // 创建测试数据
1987
+ await serverDB.insert(messages).values({
1988
+ id: 'msg2',
1989
+ userId,
1990
+ role: 'user',
1991
+ content: 'test message',
1992
+ });
1993
+
1994
+ // 调用 createMessageQuery 方法
1995
+ const result = await messageModel.createMessageQuery({
1996
+ messageId: 'msg2',
1997
+ userQuery: 'test query',
1998
+ rewriteQuery: 'test rewritten query',
1999
+ embeddingsId,
2000
+ });
2001
+
2002
+ // 断言结果
2003
+ expect(result).toBeDefined();
2004
+ expect(result.embeddingsId).toBe(embeddingsId);
2005
+
2006
+ // 验证数据库中的记录
2007
+ const dbResult = await serverDB
2008
+ .select()
2009
+ .from(messageQueries)
2010
+ .where(eq(messageQueries.id, result.id));
2011
+
2012
+ expect(dbResult[0].embeddingsId).toBe(embeddingsId);
2013
+ });
2014
+
2015
+ it('should generate a unique ID for each message query', async () => {
2016
+ // 创建测试数据
2017
+ await serverDB.insert(messages).values({
2018
+ id: 'msg3',
2019
+ userId,
2020
+ role: 'user',
2021
+ content: 'test message',
2022
+ });
2023
+
2024
+ // 连续创建两个消息查询
2025
+ const result1 = await messageModel.createMessageQuery({
2026
+ messageId: 'msg3',
2027
+ userQuery: 'query 1',
2028
+ rewriteQuery: 'rewritten query 1',
2029
+ embeddingsId,
2030
+ });
2031
+
2032
+ const result2 = await messageModel.createMessageQuery({
2033
+ messageId: 'msg3',
2034
+ userQuery: 'query 2',
2035
+ rewriteQuery: 'rewritten query 2',
2036
+ embeddingsId,
2037
+ });
2038
+
2039
+ // 断言结果
2040
+ expect(result1.id).not.toBe(result2.id);
2041
+ });
2042
+ });
2043
+
2044
+ describe('deleteMessageQuery', () => {
2045
+ it('should delete a message query by ID', async () => {
2046
+ // 创建测试数据
2047
+ const queryId = uuid();
2048
+ await serverDB.insert(messages).values({
2049
+ id: 'msg4',
2050
+ userId,
2051
+ role: 'user',
2052
+ content: 'test message',
2053
+ });
2054
+
2055
+ await serverDB.insert(messageQueries).values({
2056
+ id: queryId,
2057
+ messageId: 'msg4',
2058
+ userQuery: 'test query',
2059
+ rewriteQuery: 'rewritten query',
2060
+ userId,
2061
+ });
2062
+
2063
+ // 验证查询已创建
2064
+ const beforeDelete = await serverDB
2065
+ .select()
2066
+ .from(messageQueries)
2067
+ .where(eq(messageQueries.id, queryId));
2068
+
2069
+ expect(beforeDelete).toHaveLength(1);
2070
+
2071
+ // 调用 deleteMessageQuery 方法
2072
+ await messageModel.deleteMessageQuery(queryId);
2073
+
2074
+ // 验证查询已删除
2075
+ const afterDelete = await serverDB
2076
+ .select()
2077
+ .from(messageQueries)
2078
+ .where(eq(messageQueries.id, queryId));
2079
+
2080
+ expect(afterDelete).toHaveLength(0);
2081
+ });
2082
+
2083
+ it('should only delete message queries belonging to the user', async () => {
2084
+ // 创建测试数据 - 其他用户的查询
2085
+ const queryId = uuid();
2086
+ await serverDB.insert(messages).values({
2087
+ id: 'msg5',
2088
+ userId: '456',
2089
+ role: 'user',
2090
+ content: 'test message',
2091
+ });
2092
+
2093
+ await serverDB.insert(messageQueries).values({
2094
+ id: queryId,
2095
+ messageId: 'msg5',
2096
+ userQuery: 'test query',
2097
+ rewriteQuery: 'rewritten query',
2098
+ userId: '456', // 其他用户
2099
+ });
2100
+
2101
+ // 调用 deleteMessageQuery 方法
2102
+ await messageModel.deleteMessageQuery(queryId);
2103
+
2104
+ // 验证查询未被删除
2105
+ const afterDelete = await serverDB
2106
+ .select()
2107
+ .from(messageQueries)
2108
+ .where(eq(messageQueries.id, queryId));
2109
+
2110
+ expect(afterDelete).toHaveLength(1);
2111
+ });
2112
+
2113
+ it('should throw error when deleting non-existent message query', async () => {
2114
+ // 调用 deleteMessageQuery 方法删除不存在的查询
2115
+ try {
2116
+ await messageModel.deleteMessageQuery('non-existent-id');
2117
+ } catch (e) {
2118
+ expect(e).toBeInstanceOf(Error);
2119
+ }
2120
+ });
2121
+ });
1526
2122
  });