@lobehub/chat 1.35.9 → 1.35.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +42 -0
- package/changelog/v1.json +14 -0
- package/package.json +2 -2
- package/src/app/(main)/repos/[id]/@menu/default.tsx +2 -1
- package/src/app/(main)/repos/[id]/page.tsx +2 -1
- package/src/database/schemas/topic.ts +3 -1
- package/src/database/server/core/dbForTest.ts +4 -6
- package/src/database/server/models/__tests__/_test_template.ts +3 -9
- package/src/database/server/models/__tests__/agent.test.ts +2 -8
- package/src/database/server/models/__tests__/asyncTask.test.ts +1 -7
- package/src/database/server/models/__tests__/chunk.test.ts +155 -16
- package/src/database/server/models/__tests__/file.test.ts +123 -15
- package/src/database/server/models/__tests__/knowledgeBase.test.ts +6 -12
- package/src/database/server/models/__tests__/message.test.ts +230 -7
- package/src/database/server/models/__tests__/nextauth.test.ts +1 -7
- package/src/database/server/models/__tests__/plugin.test.ts +1 -7
- package/src/database/server/models/__tests__/session.test.ts +169 -11
- package/src/database/server/models/__tests__/sessionGroup.test.ts +2 -8
- package/src/database/server/models/__tests__/topic.test.ts +1 -7
- package/src/database/server/models/__tests__/user.test.ts +55 -20
- package/src/database/server/models/_template.ts +10 -8
- package/src/database/server/models/agent.ts +17 -13
- package/src/database/server/models/asyncTask.ts +11 -9
- package/src/database/server/models/chunk.ts +19 -14
- package/src/database/server/models/embedding.ts +10 -8
- package/src/database/server/models/file.ts +19 -17
- package/src/database/server/models/knowledgeBase.ts +14 -12
- package/src/database/server/models/message.ts +36 -34
- package/src/database/server/models/plugin.ts +10 -8
- package/src/database/server/models/session.ts +23 -64
- package/src/database/server/models/sessionGroup.ts +11 -9
- package/src/database/server/models/thread.ts +11 -9
- package/src/database/server/models/topic.ts +19 -22
- package/src/database/server/models/user.ts +96 -84
- package/src/database/type.ts +7 -0
- package/src/libs/next-auth/adapter/index.ts +10 -10
- package/src/libs/trpc/async/asyncAuth.ts +2 -1
- package/src/server/routers/async/file.ts +5 -4
- package/src/server/routers/async/ragEval.ts +4 -3
- package/src/server/routers/lambda/_template.ts +2 -1
- package/src/server/routers/lambda/agent.ts +6 -5
- package/src/server/routers/lambda/chunk.ts +5 -5
- package/src/server/routers/lambda/file.ts +4 -3
- package/src/server/routers/lambda/knowledgeBase.ts +2 -1
- package/src/server/routers/lambda/message.ts +4 -2
- package/src/server/routers/lambda/plugin.ts +4 -2
- package/src/server/routers/lambda/ragEval.ts +2 -1
- package/src/server/routers/lambda/session.ts +4 -3
- package/src/server/routers/lambda/sessionGroup.ts +2 -1
- package/src/server/routers/lambda/thread.ts +3 -2
- package/src/server/routers/lambda/topic.ts +4 -2
- package/src/server/routers/lambda/user.ts +10 -9
- package/src/server/services/chunk/index.ts +3 -2
- package/src/server/services/nextAuthUser/index.ts +3 -3
- package/src/server/services/user/index.ts +7 -6
@@ -13,15 +13,9 @@ import { UserModel } from '../user';
|
|
13
13
|
|
14
14
|
let serverDB = await getTestDBInstance();
|
15
15
|
|
16
|
-
vi.mock('@/database/server/core/db', async () => ({
|
17
|
-
get serverDB() {
|
18
|
-
return serverDB;
|
19
|
-
},
|
20
|
-
}));
|
21
|
-
|
22
16
|
const userId = 'user-db';
|
23
17
|
const userEmail = 'user@example.com';
|
24
|
-
const userModel = new UserModel();
|
18
|
+
const userModel = new UserModel(serverDB, userId);
|
25
19
|
|
26
20
|
beforeEach(async () => {
|
27
21
|
await serverDB.delete(users);
|
@@ -44,14 +38,14 @@ describe('UserModel', () => {
|
|
44
38
|
email: 'test@example.com',
|
45
39
|
};
|
46
40
|
|
47
|
-
await UserModel.createUser(params);
|
41
|
+
await UserModel.createUser(serverDB, params);
|
48
42
|
|
49
43
|
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
50
44
|
expect(user).not.toBeNull();
|
51
45
|
expect(user?.username).toBe('testuser');
|
52
46
|
expect(user?.email).toBe('test@example.com');
|
53
47
|
|
54
|
-
const sessionModel = new SessionModel(userId);
|
48
|
+
const sessionModel = new SessionModel(serverDB, userId);
|
55
49
|
const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID);
|
56
50
|
expect(inbox).not.toBeNull();
|
57
51
|
});
|
@@ -61,7 +55,7 @@ describe('UserModel', () => {
|
|
61
55
|
it('should delete a user', async () => {
|
62
56
|
await serverDB.insert(users).values({ id: userId });
|
63
57
|
|
64
|
-
await UserModel.deleteUser(userId);
|
58
|
+
await UserModel.deleteUser(serverDB, userId);
|
65
59
|
|
66
60
|
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
67
61
|
expect(user).toBeUndefined();
|
@@ -72,7 +66,7 @@ describe('UserModel', () => {
|
|
72
66
|
it('should find a user by ID', async () => {
|
73
67
|
await serverDB.insert(users).values({ id: userId, username: 'testuser' });
|
74
68
|
|
75
|
-
const user = await UserModel.findById(userId);
|
69
|
+
const user = await UserModel.findById(serverDB, userId);
|
76
70
|
|
77
71
|
expect(user).not.toBeNull();
|
78
72
|
expect(user?.id).toBe(userId);
|
@@ -84,7 +78,7 @@ describe('UserModel', () => {
|
|
84
78
|
it('should find a user by email', async () => {
|
85
79
|
await serverDB.insert(users).values({ id: userId, email: userEmail });
|
86
80
|
|
87
|
-
const user = await UserModel.findByEmail(userEmail);
|
81
|
+
const user = await UserModel.findByEmail(serverDB, userEmail);
|
88
82
|
|
89
83
|
expect(user).not.toBeNull();
|
90
84
|
expect(user?.id).toBe(userId);
|
@@ -107,7 +101,7 @@ describe('UserModel', () => {
|
|
107
101
|
keyVaults: encryptedKeyVaults,
|
108
102
|
});
|
109
103
|
|
110
|
-
const state = await userModel.getUserState(
|
104
|
+
const state = await userModel.getUserState();
|
111
105
|
|
112
106
|
expect(state.userId).toBe(userId);
|
113
107
|
expect(state.preference).toEqual(preference);
|
@@ -115,7 +109,9 @@ describe('UserModel', () => {
|
|
115
109
|
});
|
116
110
|
|
117
111
|
it('should throw an error if user not found', async () => {
|
118
|
-
|
112
|
+
const userModel = new UserModel(serverDB, 'invalid-user-id');
|
113
|
+
|
114
|
+
await expect(userModel.getUserState()).rejects.toThrow('user not found');
|
119
115
|
});
|
120
116
|
});
|
121
117
|
|
@@ -123,7 +119,7 @@ describe('UserModel', () => {
|
|
123
119
|
it('should update user fields', async () => {
|
124
120
|
await serverDB.insert(users).values({ id: userId, username: 'oldname' });
|
125
121
|
|
126
|
-
await userModel.updateUser(
|
122
|
+
await userModel.updateUser({ username: 'newname' });
|
127
123
|
|
128
124
|
const updatedUser = await serverDB.query.users.findFirst({
|
129
125
|
where: eq(users.id, userId),
|
@@ -137,7 +133,7 @@ describe('UserModel', () => {
|
|
137
133
|
await serverDB.insert(users).values({ id: userId });
|
138
134
|
await serverDB.insert(userSettings).values({ id: userId });
|
139
135
|
|
140
|
-
await userModel.deleteSetting(
|
136
|
+
await userModel.deleteSetting();
|
141
137
|
|
142
138
|
const settings = await serverDB.query.userSettings.findFirst({
|
143
139
|
where: eq(users.id, userId),
|
@@ -155,7 +151,7 @@ describe('UserModel', () => {
|
|
155
151
|
} as UserSettings;
|
156
152
|
await serverDB.insert(users).values({ id: userId });
|
157
153
|
|
158
|
-
await userModel.updateSetting(
|
154
|
+
await userModel.updateSetting(settings);
|
159
155
|
|
160
156
|
const updatedSettings = await serverDB.query.userSettings.findFirst({
|
161
157
|
where: eq(users.id, userId),
|
@@ -178,7 +174,7 @@ describe('UserModel', () => {
|
|
178
174
|
const newSettings = {
|
179
175
|
general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' },
|
180
176
|
} as UserSettings;
|
181
|
-
await userModel.updateSetting(
|
177
|
+
await userModel.updateSetting(newSettings);
|
182
178
|
|
183
179
|
const updatedSettings = await serverDB.query.userSettings.findFirst({
|
184
180
|
where: eq(users.id, userId),
|
@@ -195,7 +191,7 @@ describe('UserModel', () => {
|
|
195
191
|
const newPreference: Partial<UserPreference> = {
|
196
192
|
guide: { topic: true, moveSettingsToAvatar: true },
|
197
193
|
};
|
198
|
-
await userModel.updatePreference(
|
194
|
+
await userModel.updatePreference(newPreference);
|
199
195
|
|
200
196
|
const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
201
197
|
expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference });
|
@@ -212,10 +208,49 @@ describe('UserModel', () => {
|
|
212
208
|
moveSettingsToAvatar: true,
|
213
209
|
uploadFileInKnowledgeBase: true,
|
214
210
|
};
|
215
|
-
await userModel.updateGuide(
|
211
|
+
await userModel.updateGuide(newGuide);
|
216
212
|
|
217
213
|
const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
218
214
|
expect(updatedUser?.preference).toEqual({ ...preference, guide: newGuide });
|
219
215
|
});
|
220
216
|
});
|
217
|
+
|
218
|
+
describe('getUserApiKeys', () => {
|
219
|
+
it('should get and decrypt user API keys', async () => {
|
220
|
+
const keyVaults = { openai: { apiKey: 'test-key' } };
|
221
|
+
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
222
|
+
const encryptedKeyVaults = await gateKeeper.encrypt(JSON.stringify(keyVaults));
|
223
|
+
|
224
|
+
const userId = 'user-api-id';
|
225
|
+
|
226
|
+
await serverDB.insert(users).values({ id: userId });
|
227
|
+
await serverDB.insert(userSettings).values({
|
228
|
+
id: userId,
|
229
|
+
keyVaults: encryptedKeyVaults,
|
230
|
+
});
|
231
|
+
|
232
|
+
const result = await UserModel.getUserApiKeys(serverDB, userId);
|
233
|
+
expect(result).toEqual(keyVaults);
|
234
|
+
});
|
235
|
+
|
236
|
+
it('should throw error when user not found', async () => {
|
237
|
+
await expect(UserModel.getUserApiKeys(serverDB, 'non-existent-id')).rejects.toThrow(
|
238
|
+
'user not found',
|
239
|
+
);
|
240
|
+
});
|
241
|
+
|
242
|
+
it('should handle decrypt failure and return empty object', async () => {
|
243
|
+
const userId = 'user-api-test-id';
|
244
|
+
// 模拟解密失败的情况
|
245
|
+
const invalidEncryptedData = 'invalid:-encrypted-:data';
|
246
|
+
await serverDB.insert(users).values({ id: userId });
|
247
|
+
await serverDB.insert(userSettings).values({
|
248
|
+
id: userId,
|
249
|
+
keyVaults: invalidEncryptedData,
|
250
|
+
});
|
251
|
+
|
252
|
+
const result = await UserModel.getUserApiKeys(serverDB, userId);
|
253
|
+
expect(result).toEqual({});
|
254
|
+
});
|
255
|
+
});
|
221
256
|
});
|
@@ -1,19 +1,21 @@
|
|
1
1
|
import { eq } from 'drizzle-orm';
|
2
2
|
import { and, desc } from 'drizzle-orm/expressions';
|
3
3
|
|
4
|
-
import {
|
4
|
+
import { LobeChatDatabase } from '@/database/type';
|
5
5
|
|
6
6
|
import { NewSessionGroup, SessionGroupItem, sessionGroups } from '../../schemas';
|
7
7
|
|
8
8
|
export class TemplateModel {
|
9
9
|
private userId: string;
|
10
|
+
private db: LobeChatDatabase;
|
10
11
|
|
11
|
-
constructor(userId: string) {
|
12
|
+
constructor(db: LobeChatDatabase, userId: string) {
|
12
13
|
this.userId = userId;
|
14
|
+
this.db = db;
|
13
15
|
}
|
14
16
|
|
15
17
|
create = async (params: NewSessionGroup) => {
|
16
|
-
const [result] = await
|
18
|
+
const [result] = await this.db
|
17
19
|
.insert(sessionGroups)
|
18
20
|
.values({ ...params, userId: this.userId })
|
19
21
|
.returning();
|
@@ -22,30 +24,30 @@ export class TemplateModel {
|
|
22
24
|
};
|
23
25
|
|
24
26
|
delete = async (id: string) => {
|
25
|
-
return
|
27
|
+
return this.db
|
26
28
|
.delete(sessionGroups)
|
27
29
|
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
28
30
|
};
|
29
31
|
|
30
32
|
deleteAll = async () => {
|
31
|
-
return
|
33
|
+
return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
|
32
34
|
};
|
33
35
|
|
34
36
|
query = async () => {
|
35
|
-
return
|
37
|
+
return this.db.query.sessionGroups.findMany({
|
36
38
|
orderBy: [desc(sessionGroups.updatedAt)],
|
37
39
|
where: eq(sessionGroups.userId, this.userId),
|
38
40
|
});
|
39
41
|
};
|
40
42
|
|
41
43
|
findById = async (id: string) => {
|
42
|
-
return
|
44
|
+
return this.db.query.sessionGroups.findFirst({
|
43
45
|
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
|
44
46
|
});
|
45
47
|
};
|
46
48
|
|
47
49
|
async update(id: string, value: Partial<SessionGroupItem>) {
|
48
|
-
return
|
50
|
+
return this.db
|
49
51
|
.update(sessionGroups)
|
50
52
|
.set({ ...value, updatedAt: new Date() })
|
51
53
|
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
@@ -1,7 +1,8 @@
|
|
1
1
|
import { inArray } from 'drizzle-orm';
|
2
2
|
import { and, desc, eq } from 'drizzle-orm/expressions';
|
3
3
|
|
4
|
-
import {
|
4
|
+
import { LobeChatDatabase } from '@/database/type';
|
5
|
+
|
5
6
|
import {
|
6
7
|
agents,
|
7
8
|
agentsFiles,
|
@@ -13,12 +14,15 @@ import {
|
|
13
14
|
|
14
15
|
export class AgentModel {
|
15
16
|
private userId: string;
|
16
|
-
|
17
|
+
private db: LobeChatDatabase;
|
18
|
+
|
19
|
+
constructor(db: LobeChatDatabase, userId: string) {
|
17
20
|
this.userId = userId;
|
21
|
+
this.db = db;
|
18
22
|
}
|
19
23
|
|
20
24
|
async getAgentConfigById(id: string) {
|
21
|
-
const agent = await
|
25
|
+
const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) });
|
22
26
|
|
23
27
|
const knowledge = await this.getAgentAssignedKnowledge(id);
|
24
28
|
|
@@ -26,14 +30,14 @@ export class AgentModel {
|
|
26
30
|
}
|
27
31
|
|
28
32
|
async getAgentAssignedKnowledge(id: string) {
|
29
|
-
const knowledgeBaseResult = await
|
33
|
+
const knowledgeBaseResult = await this.db
|
30
34
|
.select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases })
|
31
35
|
.from(agentsKnowledgeBases)
|
32
36
|
.where(eq(agentsKnowledgeBases.agentId, id))
|
33
37
|
.orderBy(desc(agentsKnowledgeBases.createdAt))
|
34
38
|
.leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId));
|
35
39
|
|
36
|
-
const fileResult = await
|
40
|
+
const fileResult = await this.db
|
37
41
|
.select({ enabled: agentsFiles.enabled, files })
|
38
42
|
.from(agentsFiles)
|
39
43
|
.where(eq(agentsFiles.agentId, id))
|
@@ -56,7 +60,7 @@ export class AgentModel {
|
|
56
60
|
* Find agent by session id
|
57
61
|
*/
|
58
62
|
async findBySessionId(sessionId: string) {
|
59
|
-
const item = await
|
63
|
+
const item = await this.db.query.agentsToSessions.findFirst({
|
60
64
|
where: eq(agentsToSessions.sessionId, sessionId),
|
61
65
|
});
|
62
66
|
if (!item) return;
|
@@ -71,7 +75,7 @@ export class AgentModel {
|
|
71
75
|
knowledgeBaseId: string,
|
72
76
|
enabled: boolean = true,
|
73
77
|
) => {
|
74
|
-
return
|
78
|
+
return this.db
|
75
79
|
.insert(agentsKnowledgeBases)
|
76
80
|
.values({
|
77
81
|
agentId,
|
@@ -83,7 +87,7 @@ export class AgentModel {
|
|
83
87
|
};
|
84
88
|
|
85
89
|
deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => {
|
86
|
-
return
|
90
|
+
return this.db
|
87
91
|
.delete(agentsKnowledgeBases)
|
88
92
|
.where(
|
89
93
|
and(
|
@@ -96,7 +100,7 @@ export class AgentModel {
|
|
96
100
|
};
|
97
101
|
|
98
102
|
toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => {
|
99
|
-
return
|
103
|
+
return this.db
|
100
104
|
.update(agentsKnowledgeBases)
|
101
105
|
.set({ enabled })
|
102
106
|
.where(
|
@@ -111,7 +115,7 @@ export class AgentModel {
|
|
111
115
|
|
112
116
|
createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => {
|
113
117
|
// Exclude the fileIds that already exist in agentsFiles, and then insert them
|
114
|
-
const existingFiles = await
|
118
|
+
const existingFiles = await this.db
|
115
119
|
.select({ id: agentsFiles.fileId })
|
116
120
|
.from(agentsFiles)
|
117
121
|
.where(
|
@@ -128,7 +132,7 @@ export class AgentModel {
|
|
128
132
|
|
129
133
|
if (needToInsertFileIds.length === 0) return;
|
130
134
|
|
131
|
-
return
|
135
|
+
return this.db
|
132
136
|
.insert(agentsFiles)
|
133
137
|
.values(
|
134
138
|
needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })),
|
@@ -137,7 +141,7 @@ export class AgentModel {
|
|
137
141
|
};
|
138
142
|
|
139
143
|
deleteAgentFile = async (agentId: string, fileId: string) => {
|
140
|
-
return
|
144
|
+
return this.db
|
141
145
|
.delete(agentsFiles)
|
142
146
|
.where(
|
143
147
|
and(
|
@@ -150,7 +154,7 @@ export class AgentModel {
|
|
150
154
|
};
|
151
155
|
|
152
156
|
toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => {
|
153
|
-
return
|
157
|
+
return this.db
|
154
158
|
.update(agentsFiles)
|
155
159
|
.set({ enabled })
|
156
160
|
.where(
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import { eq, inArray, lt } from 'drizzle-orm';
|
2
2
|
import { and } from 'drizzle-orm/expressions';
|
3
3
|
|
4
|
-
import {
|
4
|
+
import { LobeChatDatabase } from '@/database/type';
|
5
5
|
import {
|
6
6
|
AsyncTaskError,
|
7
7
|
AsyncTaskErrorType,
|
@@ -16,13 +16,15 @@ export const ASYNC_TASK_TIMEOUT = 298 * 1000;
|
|
16
16
|
|
17
17
|
export class AsyncTaskModel {
|
18
18
|
private userId: string;
|
19
|
+
private db: LobeChatDatabase;
|
19
20
|
|
20
|
-
constructor(userId: string) {
|
21
|
+
constructor(db: LobeChatDatabase, userId: string) {
|
21
22
|
this.userId = userId;
|
23
|
+
this.db = db;
|
22
24
|
}
|
23
25
|
|
24
26
|
create = async (params: Pick<NewAsyncTaskItem, 'type' | 'status'>): Promise<string> => {
|
25
|
-
const data = await
|
27
|
+
const data = await this.db
|
26
28
|
.insert(asyncTasks)
|
27
29
|
.values({ ...params, userId: this.userId })
|
28
30
|
.returning();
|
@@ -31,17 +33,17 @@ export class AsyncTaskModel {
|
|
31
33
|
};
|
32
34
|
|
33
35
|
delete = async (id: string) => {
|
34
|
-
return
|
36
|
+
return this.db
|
35
37
|
.delete(asyncTasks)
|
36
38
|
.where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId)));
|
37
39
|
};
|
38
40
|
|
39
41
|
findById = async (id: string) => {
|
40
|
-
return
|
42
|
+
return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
|
41
43
|
};
|
42
44
|
|
43
45
|
update(taskId: string, value: Partial<AsyncTaskSelectItem>) {
|
44
|
-
return
|
46
|
+
return this.db
|
45
47
|
.update(asyncTasks)
|
46
48
|
.set({ ...value, updatedAt: new Date() })
|
47
49
|
.where(and(eq(asyncTasks.id, taskId)));
|
@@ -52,7 +54,7 @@ export class AsyncTaskModel {
|
|
52
54
|
|
53
55
|
if (taskIds.length > 0) {
|
54
56
|
await this.checkTimeoutTasks(taskIds);
|
55
|
-
chunkTasks = await
|
57
|
+
chunkTasks = await this.db.query.asyncTasks.findMany({
|
56
58
|
where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)),
|
57
59
|
});
|
58
60
|
}
|
@@ -64,7 +66,7 @@ export class AsyncTaskModel {
|
|
64
66
|
* make the task status to be `error` if the task is not finished in 20 seconds
|
65
67
|
*/
|
66
68
|
async checkTimeoutTasks(ids: string[]) {
|
67
|
-
const tasks = await
|
69
|
+
const tasks = await this.db
|
68
70
|
.select({ id: asyncTasks.id })
|
69
71
|
.from(asyncTasks)
|
70
72
|
.where(
|
@@ -76,7 +78,7 @@ export class AsyncTaskModel {
|
|
76
78
|
);
|
77
79
|
|
78
80
|
if (tasks.length > 0) {
|
79
|
-
await
|
81
|
+
await this.db
|
80
82
|
.update(asyncTasks)
|
81
83
|
.set({
|
82
84
|
error: new AsyncTaskError(
|
@@ -2,7 +2,7 @@ import { asc, cosineDistance, count, eq, inArray, sql } from 'drizzle-orm';
|
|
2
2
|
import { and, desc, isNull } from 'drizzle-orm/expressions';
|
3
3
|
import { chunk } from 'lodash-es';
|
4
4
|
|
5
|
-
import {
|
5
|
+
import { LobeChatDatabase } from '@/database/type';
|
6
6
|
import { ChunkMetadata, FileChunk } from '@/types/chunk';
|
7
7
|
|
8
8
|
import {
|
@@ -18,12 +18,17 @@ import {
|
|
18
18
|
export class ChunkModel {
|
19
19
|
private userId: string;
|
20
20
|
|
21
|
-
|
21
|
+
private db: LobeChatDatabase;
|
22
|
+
|
23
|
+
constructor(db: LobeChatDatabase, userId: string) {
|
22
24
|
this.userId = userId;
|
25
|
+
this.db = db;
|
23
26
|
}
|
24
27
|
|
25
28
|
bulkCreate = async (params: NewChunkItem[], fileId: string) => {
|
26
|
-
return
|
29
|
+
return this.db.transaction(async (trx) => {
|
30
|
+
if (params.length === 0) return [];
|
31
|
+
|
27
32
|
const result = await trx.insert(chunks).values(params).returning();
|
28
33
|
|
29
34
|
const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId }));
|
@@ -37,15 +42,15 @@ export class ChunkModel {
|
|
37
42
|
};
|
38
43
|
|
39
44
|
bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => {
|
40
|
-
return
|
45
|
+
return this.db.insert(unstructuredChunks).values(params);
|
41
46
|
};
|
42
47
|
|
43
48
|
delete = async (id: string) => {
|
44
|
-
return
|
49
|
+
return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
|
45
50
|
};
|
46
51
|
|
47
52
|
deleteOrphanChunks = async () => {
|
48
|
-
const orphanedChunks = await
|
53
|
+
const orphanedChunks = await this.db
|
49
54
|
.select({ chunkId: chunks.id })
|
50
55
|
.from(chunks)
|
51
56
|
.leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
|
@@ -56,7 +61,7 @@ export class ChunkModel {
|
|
56
61
|
|
57
62
|
const list = chunk(ids, 500);
|
58
63
|
|
59
|
-
await
|
64
|
+
await this.db.transaction(async (trx) => {
|
60
65
|
await Promise.all(
|
61
66
|
list.map(async (chunkIds) => {
|
62
67
|
await trx.delete(chunks).where(inArray(chunks.id, chunkIds));
|
@@ -66,13 +71,13 @@ export class ChunkModel {
|
|
66
71
|
};
|
67
72
|
|
68
73
|
findById = async (id: string) => {
|
69
|
-
return
|
74
|
+
return this.db.query.chunks.findFirst({
|
70
75
|
where: and(eq(chunks.id, id)),
|
71
76
|
});
|
72
77
|
};
|
73
78
|
|
74
79
|
async findByFileId(id: string, page = 0) {
|
75
|
-
const data = await
|
80
|
+
const data = await this.db
|
76
81
|
.select({
|
77
82
|
abstract: chunks.abstract,
|
78
83
|
createdAt: chunks.createdAt,
|
@@ -98,7 +103,7 @@ export class ChunkModel {
|
|
98
103
|
}
|
99
104
|
|
100
105
|
async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> {
|
101
|
-
const data = await
|
106
|
+
const data = await this.db
|
102
107
|
.select()
|
103
108
|
.from(chunks)
|
104
109
|
.innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
|
@@ -113,7 +118,7 @@ export class ChunkModel {
|
|
113
118
|
async countByFileIds(ids: string[]) {
|
114
119
|
if (ids.length === 0) return [];
|
115
120
|
|
116
|
-
return
|
121
|
+
return this.db
|
117
122
|
.select({
|
118
123
|
count: count(fileChunks.chunkId),
|
119
124
|
id: fileChunks.fileId,
|
@@ -124,7 +129,7 @@ export class ChunkModel {
|
|
124
129
|
}
|
125
130
|
|
126
131
|
async countByFileId(ids: string) {
|
127
|
-
const data = await
|
132
|
+
const data = await this.db
|
128
133
|
.select({
|
129
134
|
count: count(fileChunks.chunkId),
|
130
135
|
id: fileChunks.fileId,
|
@@ -146,7 +151,7 @@ export class ChunkModel {
|
|
146
151
|
}) {
|
147
152
|
const similarity = sql<number>`1 - (${cosineDistance(embeddings.embeddings, embedding)})`;
|
148
153
|
|
149
|
-
const data = await
|
154
|
+
const data = await this.db
|
150
155
|
.select({
|
151
156
|
fileId: fileChunks.fileId,
|
152
157
|
fileName: files.name,
|
@@ -185,7 +190,7 @@ export class ChunkModel {
|
|
185
190
|
|
186
191
|
if (!hasFiles) return [];
|
187
192
|
|
188
|
-
const result = await
|
193
|
+
const result = await this.db
|
189
194
|
.select({
|
190
195
|
fileId: files.id,
|
191
196
|
fileName: files.name,
|
@@ -1,19 +1,21 @@
|
|
1
1
|
import { count, eq } from 'drizzle-orm';
|
2
2
|
import { and } from 'drizzle-orm/expressions';
|
3
3
|
|
4
|
-
import {
|
4
|
+
import { LobeChatDatabase } from '@/database/type';
|
5
5
|
|
6
6
|
import { NewEmbeddingsItem, embeddings } from '../../schemas';
|
7
7
|
|
8
8
|
export class EmbeddingModel {
|
9
9
|
private userId: string;
|
10
|
+
private db: LobeChatDatabase;
|
10
11
|
|
11
|
-
constructor(userId: string) {
|
12
|
+
constructor(db: LobeChatDatabase, userId: string) {
|
12
13
|
this.userId = userId;
|
14
|
+
this.db = db;
|
13
15
|
}
|
14
16
|
|
15
17
|
create = async (value: Omit<NewEmbeddingsItem, 'userId'>) => {
|
16
|
-
const [item] = await
|
18
|
+
const [item] = await this.db
|
17
19
|
.insert(embeddings)
|
18
20
|
.values({ ...value, userId: this.userId })
|
19
21
|
.returning();
|
@@ -22,7 +24,7 @@ export class EmbeddingModel {
|
|
22
24
|
};
|
23
25
|
|
24
26
|
bulkCreate = async (values: Omit<NewEmbeddingsItem, 'userId'>[]) => {
|
25
|
-
return
|
27
|
+
return this.db
|
26
28
|
.insert(embeddings)
|
27
29
|
.values(values.map((item) => ({ ...item, userId: this.userId })))
|
28
30
|
.onConflictDoNothing({
|
@@ -31,25 +33,25 @@ export class EmbeddingModel {
|
|
31
33
|
};
|
32
34
|
|
33
35
|
delete = async (id: string) => {
|
34
|
-
return
|
36
|
+
return this.db
|
35
37
|
.delete(embeddings)
|
36
38
|
.where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)));
|
37
39
|
};
|
38
40
|
|
39
41
|
query = async () => {
|
40
|
-
return
|
42
|
+
return this.db.query.embeddings.findMany({
|
41
43
|
where: eq(embeddings.userId, this.userId),
|
42
44
|
});
|
43
45
|
};
|
44
46
|
|
45
47
|
findById = async (id: string) => {
|
46
|
-
return
|
48
|
+
return this.db.query.embeddings.findFirst({
|
47
49
|
where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)),
|
48
50
|
});
|
49
51
|
};
|
50
52
|
|
51
53
|
countUsage = async () => {
|
52
|
-
const result = await
|
54
|
+
const result = await this.db
|
53
55
|
.select({
|
54
56
|
count: count(),
|
55
57
|
})
|