@lobehub/chat 1.35.9 → 1.35.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/CHANGELOG.md +42 -0
  2. package/changelog/v1.json +14 -0
  3. package/package.json +2 -2
  4. package/src/app/(main)/repos/[id]/@menu/default.tsx +2 -1
  5. package/src/app/(main)/repos/[id]/page.tsx +2 -1
  6. package/src/database/schemas/topic.ts +3 -1
  7. package/src/database/server/core/dbForTest.ts +4 -6
  8. package/src/database/server/models/__tests__/_test_template.ts +3 -9
  9. package/src/database/server/models/__tests__/agent.test.ts +2 -8
  10. package/src/database/server/models/__tests__/asyncTask.test.ts +1 -7
  11. package/src/database/server/models/__tests__/chunk.test.ts +155 -16
  12. package/src/database/server/models/__tests__/file.test.ts +123 -15
  13. package/src/database/server/models/__tests__/knowledgeBase.test.ts +6 -12
  14. package/src/database/server/models/__tests__/message.test.ts +230 -7
  15. package/src/database/server/models/__tests__/nextauth.test.ts +1 -7
  16. package/src/database/server/models/__tests__/plugin.test.ts +1 -7
  17. package/src/database/server/models/__tests__/session.test.ts +169 -11
  18. package/src/database/server/models/__tests__/sessionGroup.test.ts +2 -8
  19. package/src/database/server/models/__tests__/topic.test.ts +1 -7
  20. package/src/database/server/models/__tests__/user.test.ts +55 -20
  21. package/src/database/server/models/_template.ts +10 -8
  22. package/src/database/server/models/agent.ts +17 -13
  23. package/src/database/server/models/asyncTask.ts +11 -9
  24. package/src/database/server/models/chunk.ts +19 -14
  25. package/src/database/server/models/embedding.ts +10 -8
  26. package/src/database/server/models/file.ts +19 -17
  27. package/src/database/server/models/knowledgeBase.ts +14 -12
  28. package/src/database/server/models/message.ts +36 -34
  29. package/src/database/server/models/plugin.ts +10 -8
  30. package/src/database/server/models/session.ts +23 -64
  31. package/src/database/server/models/sessionGroup.ts +11 -9
  32. package/src/database/server/models/thread.ts +11 -9
  33. package/src/database/server/models/topic.ts +19 -22
  34. package/src/database/server/models/user.ts +96 -84
  35. package/src/database/type.ts +7 -0
  36. package/src/libs/next-auth/adapter/index.ts +10 -10
  37. package/src/libs/trpc/async/asyncAuth.ts +2 -1
  38. package/src/server/routers/async/file.ts +5 -4
  39. package/src/server/routers/async/ragEval.ts +4 -3
  40. package/src/server/routers/lambda/_template.ts +2 -1
  41. package/src/server/routers/lambda/agent.ts +6 -5
  42. package/src/server/routers/lambda/chunk.ts +5 -5
  43. package/src/server/routers/lambda/file.ts +4 -3
  44. package/src/server/routers/lambda/knowledgeBase.ts +2 -1
  45. package/src/server/routers/lambda/message.ts +4 -2
  46. package/src/server/routers/lambda/plugin.ts +4 -2
  47. package/src/server/routers/lambda/ragEval.ts +2 -1
  48. package/src/server/routers/lambda/session.ts +4 -3
  49. package/src/server/routers/lambda/sessionGroup.ts +2 -1
  50. package/src/server/routers/lambda/thread.ts +3 -2
  51. package/src/server/routers/lambda/topic.ts +4 -2
  52. package/src/server/routers/lambda/user.ts +10 -9
  53. package/src/server/services/chunk/index.ts +3 -2
  54. package/src/server/services/nextAuthUser/index.ts +3 -3
  55. package/src/server/services/user/index.ts +7 -6
@@ -13,15 +13,9 @@ import { UserModel } from '../user';
13
13
 
14
14
  let serverDB = await getTestDBInstance();
15
15
 
16
- vi.mock('@/database/server/core/db', async () => ({
17
- get serverDB() {
18
- return serverDB;
19
- },
20
- }));
21
-
22
16
  const userId = 'user-db';
23
17
  const userEmail = 'user@example.com';
24
- const userModel = new UserModel();
18
+ const userModel = new UserModel(serverDB, userId);
25
19
 
26
20
  beforeEach(async () => {
27
21
  await serverDB.delete(users);
@@ -44,14 +38,14 @@ describe('UserModel', () => {
44
38
  email: 'test@example.com',
45
39
  };
46
40
 
47
- await UserModel.createUser(params);
41
+ await UserModel.createUser(serverDB, params);
48
42
 
49
43
  const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
50
44
  expect(user).not.toBeNull();
51
45
  expect(user?.username).toBe('testuser');
52
46
  expect(user?.email).toBe('test@example.com');
53
47
 
54
- const sessionModel = new SessionModel(userId);
48
+ const sessionModel = new SessionModel(serverDB, userId);
55
49
  const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID);
56
50
  expect(inbox).not.toBeNull();
57
51
  });
@@ -61,7 +55,7 @@ describe('UserModel', () => {
61
55
  it('should delete a user', async () => {
62
56
  await serverDB.insert(users).values({ id: userId });
63
57
 
64
- await UserModel.deleteUser(userId);
58
+ await UserModel.deleteUser(serverDB, userId);
65
59
 
66
60
  const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
67
61
  expect(user).toBeUndefined();
@@ -72,7 +66,7 @@ describe('UserModel', () => {
72
66
  it('should find a user by ID', async () => {
73
67
  await serverDB.insert(users).values({ id: userId, username: 'testuser' });
74
68
 
75
- const user = await UserModel.findById(userId);
69
+ const user = await UserModel.findById(serverDB, userId);
76
70
 
77
71
  expect(user).not.toBeNull();
78
72
  expect(user?.id).toBe(userId);
@@ -84,7 +78,7 @@ describe('UserModel', () => {
84
78
  it('should find a user by email', async () => {
85
79
  await serverDB.insert(users).values({ id: userId, email: userEmail });
86
80
 
87
- const user = await UserModel.findByEmail(userEmail);
81
+ const user = await UserModel.findByEmail(serverDB, userEmail);
88
82
 
89
83
  expect(user).not.toBeNull();
90
84
  expect(user?.id).toBe(userId);
@@ -107,7 +101,7 @@ describe('UserModel', () => {
107
101
  keyVaults: encryptedKeyVaults,
108
102
  });
109
103
 
110
- const state = await userModel.getUserState(userId);
104
+ const state = await userModel.getUserState();
111
105
 
112
106
  expect(state.userId).toBe(userId);
113
107
  expect(state.preference).toEqual(preference);
@@ -115,7 +109,9 @@ describe('UserModel', () => {
115
109
  });
116
110
 
117
111
  it('should throw an error if user not found', async () => {
118
- await expect(userModel.getUserState('invalid-user-id')).rejects.toThrow('user not found');
112
+ const userModel = new UserModel(serverDB, 'invalid-user-id');
113
+
114
+ await expect(userModel.getUserState()).rejects.toThrow('user not found');
119
115
  });
120
116
  });
121
117
 
@@ -123,7 +119,7 @@ describe('UserModel', () => {
123
119
  it('should update user fields', async () => {
124
120
  await serverDB.insert(users).values({ id: userId, username: 'oldname' });
125
121
 
126
- await userModel.updateUser(userId, { username: 'newname' });
122
+ await userModel.updateUser({ username: 'newname' });
127
123
 
128
124
  const updatedUser = await serverDB.query.users.findFirst({
129
125
  where: eq(users.id, userId),
@@ -137,7 +133,7 @@ describe('UserModel', () => {
137
133
  await serverDB.insert(users).values({ id: userId });
138
134
  await serverDB.insert(userSettings).values({ id: userId });
139
135
 
140
- await userModel.deleteSetting(userId);
136
+ await userModel.deleteSetting();
141
137
 
142
138
  const settings = await serverDB.query.userSettings.findFirst({
143
139
  where: eq(users.id, userId),
@@ -155,7 +151,7 @@ describe('UserModel', () => {
155
151
  } as UserSettings;
156
152
  await serverDB.insert(users).values({ id: userId });
157
153
 
158
- await userModel.updateSetting(userId, settings);
154
+ await userModel.updateSetting(settings);
159
155
 
160
156
  const updatedSettings = await serverDB.query.userSettings.findFirst({
161
157
  where: eq(users.id, userId),
@@ -178,7 +174,7 @@ describe('UserModel', () => {
178
174
  const newSettings = {
179
175
  general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' },
180
176
  } as UserSettings;
181
- await userModel.updateSetting(userId, newSettings);
177
+ await userModel.updateSetting(newSettings);
182
178
 
183
179
  const updatedSettings = await serverDB.query.userSettings.findFirst({
184
180
  where: eq(users.id, userId),
@@ -195,7 +191,7 @@ describe('UserModel', () => {
195
191
  const newPreference: Partial<UserPreference> = {
196
192
  guide: { topic: true, moveSettingsToAvatar: true },
197
193
  };
198
- await userModel.updatePreference(userId, newPreference);
194
+ await userModel.updatePreference(newPreference);
199
195
 
200
196
  const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
201
197
  expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference });
@@ -212,10 +208,49 @@ describe('UserModel', () => {
212
208
  moveSettingsToAvatar: true,
213
209
  uploadFileInKnowledgeBase: true,
214
210
  };
215
- await userModel.updateGuide(userId, newGuide);
211
+ await userModel.updateGuide(newGuide);
216
212
 
217
213
  const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
218
214
  expect(updatedUser?.preference).toEqual({ ...preference, guide: newGuide });
219
215
  });
220
216
  });
217
+
218
+ describe('getUserApiKeys', () => {
219
+ it('should get and decrypt user API keys', async () => {
220
+ const keyVaults = { openai: { apiKey: 'test-key' } };
221
+ const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
222
+ const encryptedKeyVaults = await gateKeeper.encrypt(JSON.stringify(keyVaults));
223
+
224
+ const userId = 'user-api-id';
225
+
226
+ await serverDB.insert(users).values({ id: userId });
227
+ await serverDB.insert(userSettings).values({
228
+ id: userId,
229
+ keyVaults: encryptedKeyVaults,
230
+ });
231
+
232
+ const result = await UserModel.getUserApiKeys(serverDB, userId);
233
+ expect(result).toEqual(keyVaults);
234
+ });
235
+
236
+ it('should throw error when user not found', async () => {
237
+ await expect(UserModel.getUserApiKeys(serverDB, 'non-existent-id')).rejects.toThrow(
238
+ 'user not found',
239
+ );
240
+ });
241
+
242
+ it('should handle decrypt failure and return empty object', async () => {
243
+ const userId = 'user-api-test-id';
244
+ // 模拟解密失败的情况
245
+ const invalidEncryptedData = 'invalid:-encrypted-:data';
246
+ await serverDB.insert(users).values({ id: userId });
247
+ await serverDB.insert(userSettings).values({
248
+ id: userId,
249
+ keyVaults: invalidEncryptedData,
250
+ });
251
+
252
+ const result = await UserModel.getUserApiKeys(serverDB, userId);
253
+ expect(result).toEqual({});
254
+ });
255
+ });
221
256
  });
@@ -1,19 +1,21 @@
1
1
  import { eq } from 'drizzle-orm';
2
2
  import { and, desc } from 'drizzle-orm/expressions';
3
3
 
4
- import { serverDB } from '@/database/server';
4
+ import { LobeChatDatabase } from '@/database/type';
5
5
 
6
6
  import { NewSessionGroup, SessionGroupItem, sessionGroups } from '../../schemas';
7
7
 
8
8
  export class TemplateModel {
9
9
  private userId: string;
10
+ private db: LobeChatDatabase;
10
11
 
11
- constructor(userId: string) {
12
+ constructor(db: LobeChatDatabase, userId: string) {
12
13
  this.userId = userId;
14
+ this.db = db;
13
15
  }
14
16
 
15
17
  create = async (params: NewSessionGroup) => {
16
- const [result] = await serverDB
18
+ const [result] = await this.db
17
19
  .insert(sessionGroups)
18
20
  .values({ ...params, userId: this.userId })
19
21
  .returning();
@@ -22,30 +24,30 @@ export class TemplateModel {
22
24
  };
23
25
 
24
26
  delete = async (id: string) => {
25
- return serverDB
27
+ return this.db
26
28
  .delete(sessionGroups)
27
29
  .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
28
30
  };
29
31
 
30
32
  deleteAll = async () => {
31
- return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
33
+ return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
32
34
  };
33
35
 
34
36
  query = async () => {
35
- return serverDB.query.sessionGroups.findMany({
37
+ return this.db.query.sessionGroups.findMany({
36
38
  orderBy: [desc(sessionGroups.updatedAt)],
37
39
  where: eq(sessionGroups.userId, this.userId),
38
40
  });
39
41
  };
40
42
 
41
43
  findById = async (id: string) => {
42
- return serverDB.query.sessionGroups.findFirst({
44
+ return this.db.query.sessionGroups.findFirst({
43
45
  where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
44
46
  });
45
47
  };
46
48
 
47
49
  async update(id: string, value: Partial<SessionGroupItem>) {
48
- return serverDB
50
+ return this.db
49
51
  .update(sessionGroups)
50
52
  .set({ ...value, updatedAt: new Date() })
51
53
  .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
@@ -1,7 +1,8 @@
1
1
  import { inArray } from 'drizzle-orm';
2
2
  import { and, desc, eq } from 'drizzle-orm/expressions';
3
3
 
4
- import { serverDB } from '@/database/server';
4
+ import { LobeChatDatabase } from '@/database/type';
5
+
5
6
  import {
6
7
  agents,
7
8
  agentsFiles,
@@ -13,12 +14,15 @@ import {
13
14
 
14
15
  export class AgentModel {
15
16
  private userId: string;
16
- constructor(userId: string) {
17
+ private db: LobeChatDatabase;
18
+
19
+ constructor(db: LobeChatDatabase, userId: string) {
17
20
  this.userId = userId;
21
+ this.db = db;
18
22
  }
19
23
 
20
24
  async getAgentConfigById(id: string) {
21
- const agent = await serverDB.query.agents.findFirst({ where: eq(agents.id, id) });
25
+ const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) });
22
26
 
23
27
  const knowledge = await this.getAgentAssignedKnowledge(id);
24
28
 
@@ -26,14 +30,14 @@ export class AgentModel {
26
30
  }
27
31
 
28
32
  async getAgentAssignedKnowledge(id: string) {
29
- const knowledgeBaseResult = await serverDB
33
+ const knowledgeBaseResult = await this.db
30
34
  .select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases })
31
35
  .from(agentsKnowledgeBases)
32
36
  .where(eq(agentsKnowledgeBases.agentId, id))
33
37
  .orderBy(desc(agentsKnowledgeBases.createdAt))
34
38
  .leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId));
35
39
 
36
- const fileResult = await serverDB
40
+ const fileResult = await this.db
37
41
  .select({ enabled: agentsFiles.enabled, files })
38
42
  .from(agentsFiles)
39
43
  .where(eq(agentsFiles.agentId, id))
@@ -56,7 +60,7 @@ export class AgentModel {
56
60
  * Find agent by session id
57
61
  */
58
62
  async findBySessionId(sessionId: string) {
59
- const item = await serverDB.query.agentsToSessions.findFirst({
63
+ const item = await this.db.query.agentsToSessions.findFirst({
60
64
  where: eq(agentsToSessions.sessionId, sessionId),
61
65
  });
62
66
  if (!item) return;
@@ -71,7 +75,7 @@ export class AgentModel {
71
75
  knowledgeBaseId: string,
72
76
  enabled: boolean = true,
73
77
  ) => {
74
- return serverDB
78
+ return this.db
75
79
  .insert(agentsKnowledgeBases)
76
80
  .values({
77
81
  agentId,
@@ -83,7 +87,7 @@ export class AgentModel {
83
87
  };
84
88
 
85
89
  deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => {
86
- return serverDB
90
+ return this.db
87
91
  .delete(agentsKnowledgeBases)
88
92
  .where(
89
93
  and(
@@ -96,7 +100,7 @@ export class AgentModel {
96
100
  };
97
101
 
98
102
  toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => {
99
- return serverDB
103
+ return this.db
100
104
  .update(agentsKnowledgeBases)
101
105
  .set({ enabled })
102
106
  .where(
@@ -111,7 +115,7 @@ export class AgentModel {
111
115
 
112
116
  createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => {
113
117
  // Exclude the fileIds that already exist in agentsFiles, and then insert them
114
- const existingFiles = await serverDB
118
+ const existingFiles = await this.db
115
119
  .select({ id: agentsFiles.fileId })
116
120
  .from(agentsFiles)
117
121
  .where(
@@ -128,7 +132,7 @@ export class AgentModel {
128
132
 
129
133
  if (needToInsertFileIds.length === 0) return;
130
134
 
131
- return serverDB
135
+ return this.db
132
136
  .insert(agentsFiles)
133
137
  .values(
134
138
  needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })),
@@ -137,7 +141,7 @@ export class AgentModel {
137
141
  };
138
142
 
139
143
  deleteAgentFile = async (agentId: string, fileId: string) => {
140
- return serverDB
144
+ return this.db
141
145
  .delete(agentsFiles)
142
146
  .where(
143
147
  and(
@@ -150,7 +154,7 @@ export class AgentModel {
150
154
  };
151
155
 
152
156
  toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => {
153
- return serverDB
157
+ return this.db
154
158
  .update(agentsFiles)
155
159
  .set({ enabled })
156
160
  .where(
@@ -1,7 +1,7 @@
1
1
  import { eq, inArray, lt } from 'drizzle-orm';
2
2
  import { and } from 'drizzle-orm/expressions';
3
3
 
4
- import { serverDB } from '@/database/server';
4
+ import { LobeChatDatabase } from '@/database/type';
5
5
  import {
6
6
  AsyncTaskError,
7
7
  AsyncTaskErrorType,
@@ -16,13 +16,15 @@ export const ASYNC_TASK_TIMEOUT = 298 * 1000;
16
16
 
17
17
  export class AsyncTaskModel {
18
18
  private userId: string;
19
+ private db: LobeChatDatabase;
19
20
 
20
- constructor(userId: string) {
21
+ constructor(db: LobeChatDatabase, userId: string) {
21
22
  this.userId = userId;
23
+ this.db = db;
22
24
  }
23
25
 
24
26
  create = async (params: Pick<NewAsyncTaskItem, 'type' | 'status'>): Promise<string> => {
25
- const data = await serverDB
27
+ const data = await this.db
26
28
  .insert(asyncTasks)
27
29
  .values({ ...params, userId: this.userId })
28
30
  .returning();
@@ -31,17 +33,17 @@ export class AsyncTaskModel {
31
33
  };
32
34
 
33
35
  delete = async (id: string) => {
34
- return serverDB
36
+ return this.db
35
37
  .delete(asyncTasks)
36
38
  .where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId)));
37
39
  };
38
40
 
39
41
  findById = async (id: string) => {
40
- return serverDB.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
42
+ return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
41
43
  };
42
44
 
43
45
  update(taskId: string, value: Partial<AsyncTaskSelectItem>) {
44
- return serverDB
46
+ return this.db
45
47
  .update(asyncTasks)
46
48
  .set({ ...value, updatedAt: new Date() })
47
49
  .where(and(eq(asyncTasks.id, taskId)));
@@ -52,7 +54,7 @@ export class AsyncTaskModel {
52
54
 
53
55
  if (taskIds.length > 0) {
54
56
  await this.checkTimeoutTasks(taskIds);
55
- chunkTasks = await serverDB.query.asyncTasks.findMany({
57
+ chunkTasks = await this.db.query.asyncTasks.findMany({
56
58
  where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)),
57
59
  });
58
60
  }
@@ -64,7 +66,7 @@ export class AsyncTaskModel {
64
66
  * make the task status to be `error` if the task is not finished in 20 seconds
65
67
  */
66
68
  async checkTimeoutTasks(ids: string[]) {
67
- const tasks = await serverDB
69
+ const tasks = await this.db
68
70
  .select({ id: asyncTasks.id })
69
71
  .from(asyncTasks)
70
72
  .where(
@@ -76,7 +78,7 @@ export class AsyncTaskModel {
76
78
  );
77
79
 
78
80
  if (tasks.length > 0) {
79
- await serverDB
81
+ await this.db
80
82
  .update(asyncTasks)
81
83
  .set({
82
84
  error: new AsyncTaskError(
@@ -2,7 +2,7 @@ import { asc, cosineDistance, count, eq, inArray, sql } from 'drizzle-orm';
2
2
  import { and, desc, isNull } from 'drizzle-orm/expressions';
3
3
  import { chunk } from 'lodash-es';
4
4
 
5
- import { serverDB } from '@/database/server';
5
+ import { LobeChatDatabase } from '@/database/type';
6
6
  import { ChunkMetadata, FileChunk } from '@/types/chunk';
7
7
 
8
8
  import {
@@ -18,12 +18,17 @@ import {
18
18
  export class ChunkModel {
19
19
  private userId: string;
20
20
 
21
- constructor(userId: string) {
21
+ private db: LobeChatDatabase;
22
+
23
+ constructor(db: LobeChatDatabase, userId: string) {
22
24
  this.userId = userId;
25
+ this.db = db;
23
26
  }
24
27
 
25
28
  bulkCreate = async (params: NewChunkItem[], fileId: string) => {
26
- return serverDB.transaction(async (trx) => {
29
+ return this.db.transaction(async (trx) => {
30
+ if (params.length === 0) return [];
31
+
27
32
  const result = await trx.insert(chunks).values(params).returning();
28
33
 
29
34
  const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId }));
@@ -37,15 +42,15 @@ export class ChunkModel {
37
42
  };
38
43
 
39
44
  bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => {
40
- return serverDB.insert(unstructuredChunks).values(params);
45
+ return this.db.insert(unstructuredChunks).values(params);
41
46
  };
42
47
 
43
48
  delete = async (id: string) => {
44
- return serverDB.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
49
+ return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
45
50
  };
46
51
 
47
52
  deleteOrphanChunks = async () => {
48
- const orphanedChunks = await serverDB
53
+ const orphanedChunks = await this.db
49
54
  .select({ chunkId: chunks.id })
50
55
  .from(chunks)
51
56
  .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
@@ -56,7 +61,7 @@ export class ChunkModel {
56
61
 
57
62
  const list = chunk(ids, 500);
58
63
 
59
- await serverDB.transaction(async (trx) => {
64
+ await this.db.transaction(async (trx) => {
60
65
  await Promise.all(
61
66
  list.map(async (chunkIds) => {
62
67
  await trx.delete(chunks).where(inArray(chunks.id, chunkIds));
@@ -66,13 +71,13 @@ export class ChunkModel {
66
71
  };
67
72
 
68
73
  findById = async (id: string) => {
69
- return serverDB.query.chunks.findFirst({
74
+ return this.db.query.chunks.findFirst({
70
75
  where: and(eq(chunks.id, id)),
71
76
  });
72
77
  };
73
78
 
74
79
  async findByFileId(id: string, page = 0) {
75
- const data = await serverDB
80
+ const data = await this.db
76
81
  .select({
77
82
  abstract: chunks.abstract,
78
83
  createdAt: chunks.createdAt,
@@ -98,7 +103,7 @@ export class ChunkModel {
98
103
  }
99
104
 
100
105
  async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> {
101
- const data = await serverDB
106
+ const data = await this.db
102
107
  .select()
103
108
  .from(chunks)
104
109
  .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
@@ -113,7 +118,7 @@ export class ChunkModel {
113
118
  async countByFileIds(ids: string[]) {
114
119
  if (ids.length === 0) return [];
115
120
 
116
- return serverDB
121
+ return this.db
117
122
  .select({
118
123
  count: count(fileChunks.chunkId),
119
124
  id: fileChunks.fileId,
@@ -124,7 +129,7 @@ export class ChunkModel {
124
129
  }
125
130
 
126
131
  async countByFileId(ids: string) {
127
- const data = await serverDB
132
+ const data = await this.db
128
133
  .select({
129
134
  count: count(fileChunks.chunkId),
130
135
  id: fileChunks.fileId,
@@ -146,7 +151,7 @@ export class ChunkModel {
146
151
  }) {
147
152
  const similarity = sql<number>`1 - (${cosineDistance(embeddings.embeddings, embedding)})`;
148
153
 
149
- const data = await serverDB
154
+ const data = await this.db
150
155
  .select({
151
156
  fileId: fileChunks.fileId,
152
157
  fileName: files.name,
@@ -185,7 +190,7 @@ export class ChunkModel {
185
190
 
186
191
  if (!hasFiles) return [];
187
192
 
188
- const result = await serverDB
193
+ const result = await this.db
189
194
  .select({
190
195
  fileId: files.id,
191
196
  fileName: files.name,
@@ -1,19 +1,21 @@
1
1
  import { count, eq } from 'drizzle-orm';
2
2
  import { and } from 'drizzle-orm/expressions';
3
3
 
4
- import { serverDB } from '@/database/server';
4
+ import { LobeChatDatabase } from '@/database/type';
5
5
 
6
6
  import { NewEmbeddingsItem, embeddings } from '../../schemas';
7
7
 
8
8
  export class EmbeddingModel {
9
9
  private userId: string;
10
+ private db: LobeChatDatabase;
10
11
 
11
- constructor(userId: string) {
12
+ constructor(db: LobeChatDatabase, userId: string) {
12
13
  this.userId = userId;
14
+ this.db = db;
13
15
  }
14
16
 
15
17
  create = async (value: Omit<NewEmbeddingsItem, 'userId'>) => {
16
- const [item] = await serverDB
18
+ const [item] = await this.db
17
19
  .insert(embeddings)
18
20
  .values({ ...value, userId: this.userId })
19
21
  .returning();
@@ -22,7 +24,7 @@ export class EmbeddingModel {
22
24
  };
23
25
 
24
26
  bulkCreate = async (values: Omit<NewEmbeddingsItem, 'userId'>[]) => {
25
- return serverDB
27
+ return this.db
26
28
  .insert(embeddings)
27
29
  .values(values.map((item) => ({ ...item, userId: this.userId })))
28
30
  .onConflictDoNothing({
@@ -31,25 +33,25 @@ export class EmbeddingModel {
31
33
  };
32
34
 
33
35
  delete = async (id: string) => {
34
- return serverDB
36
+ return this.db
35
37
  .delete(embeddings)
36
38
  .where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)));
37
39
  };
38
40
 
39
41
  query = async () => {
40
- return serverDB.query.embeddings.findMany({
42
+ return this.db.query.embeddings.findMany({
41
43
  where: eq(embeddings.userId, this.userId),
42
44
  });
43
45
  };
44
46
 
45
47
  findById = async (id: string) => {
46
- return serverDB.query.embeddings.findFirst({
48
+ return this.db.query.embeddings.findFirst({
47
49
  where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)),
48
50
  });
49
51
  };
50
52
 
51
53
  countUsage = async () => {
52
- const result = await serverDB
54
+ const result = await this.db
53
55
  .select({
54
56
  count: count(),
55
57
  })