@lobehub/chat 1.21.11 → 1.21.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. package/CHANGELOG.md +50 -0
  2. package/package.json +2 -2
  3. package/src/app/(backend)/api/chat/[provider]/route.test.ts +6 -2
  4. package/src/app/(backend)/middleware/auth/index.test.ts +5 -2
  5. package/src/app/(backend)/middleware/auth/index.ts +2 -1
  6. package/src/app/(backend)/middleware/auth/utils.test.ts +1 -38
  7. package/src/app/(backend)/middleware/auth/utils.ts +1 -33
  8. package/src/app/(backend)/webapi/plugin/gateway/route.ts +1 -1
  9. package/src/libs/agent-runtime/AgentRuntime.ts +1 -1
  10. package/src/libs/agent-runtime/google/index.ts +2 -2
  11. package/src/libs/agent-runtime/utils/streams/anthropic.ts +2 -8
  12. package/src/libs/agent-runtime/utils/streams/azureOpenai.ts +2 -8
  13. package/src/libs/agent-runtime/utils/streams/google-ai.ts +1 -12
  14. package/src/libs/agent-runtime/utils/streams/ollama.ts +2 -8
  15. package/src/libs/agent-runtime/utils/streams/openai.ts +2 -8
  16. package/src/libs/agent-runtime/utils/streams/protocol.ts +7 -0
  17. package/src/libs/agent-runtime/utils/streams/qwen.ts +2 -3
  18. package/src/libs/agent-runtime/utils/streams/wenxin.test.ts +7 -3
  19. package/src/libs/agent-runtime/utils/streams/wenxin.ts +0 -8
  20. package/src/libs/agent-runtime/wenxin/index.ts +3 -2
  21. package/src/libs/agent-runtime/zhipu/index.test.ts +7 -24
  22. package/src/libs/agent-runtime/zhipu/index.ts +21 -99
  23. package/src/libs/trpc/middleware/jwtPayload.test.ts +1 -1
  24. package/src/libs/trpc/middleware/jwtPayload.ts +1 -1
  25. package/src/libs/trpc/middleware/keyVaults.ts +1 -1
  26. package/src/utils/server/jwt.test.ts +62 -0
  27. package/src/utils/server/jwt.ts +32 -0
package/CHANGELOG.md CHANGED
@@ -2,6 +2,56 @@
2
2
 
3
3
  # Changelog
4
4
 
5
+ ### [Version 1.21.13](https://github.com/lobehub/lobe-chat/compare/v1.21.12...v1.21.13)
6
+
7
+ <sup>Released on **2024-10-11**</sup>
8
+
9
+ #### ♻ Code Refactoring
10
+
11
+ - **misc**: Refactor agent runtime implement of stream and ZHIPU provider.
12
+
13
+ <br/>
14
+
15
+ <details>
16
+ <summary><kbd>Improvements and Fixes</kbd></summary>
17
+
18
+ #### Code refactoring
19
+
20
+ - **misc**: Refactor agent runtime implement of stream and ZHIPU provider, closes [#4323](https://github.com/lobehub/lobe-chat/issues/4323) ([59661a1](https://github.com/lobehub/lobe-chat/commit/59661a1))
21
+
22
+ </details>
23
+
24
+ <div align="right">
25
+
26
+ [![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top)
27
+
28
+ </div>
29
+
30
+ ### [Version 1.21.12](https://github.com/lobehub/lobe-chat/compare/v1.21.11...v1.21.12)
31
+
32
+ <sup>Released on **2024-10-11**</sup>
33
+
34
+ #### ♻ Code Refactoring
35
+
36
+ - **misc**: Refactor the jwt code.
37
+
38
+ <br/>
39
+
40
+ <details>
41
+ <summary><kbd>Improvements and Fixes</kbd></summary>
42
+
43
+ #### Code refactoring
44
+
45
+ - **misc**: Refactor the jwt code, closes [#4322](https://github.com/lobehub/lobe-chat/issues/4322) ([b7258b9](https://github.com/lobehub/lobe-chat/commit/b7258b9))
46
+
47
+ </details>
48
+
49
+ <div align="right">
50
+
51
+ [![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top)
52
+
53
+ </div>
54
+
5
55
  ### [Version 1.21.11](https://github.com/lobehub/lobe-chat/compare/v1.21.10...v1.21.11)
6
56
 
7
57
  <sup>Released on **2024-10-11**</sup>
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@lobehub/chat",
3
- "version": "1.21.11",
3
+ "version": "1.21.13",
4
4
  "description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.",
5
5
  "keywords": [
6
6
  "framework",
@@ -101,7 +101,7 @@
101
101
  "dependencies": {
102
102
  "@ant-design/icons": "^5.4.0",
103
103
  "@ant-design/pro-components": "^2.7.10",
104
- "@anthropic-ai/sdk": "^0.27.0",
104
+ "@anthropic-ai/sdk": "^0.29.0",
105
105
  "@auth/core": "^0.34.2",
106
106
  "@aws-sdk/client-bedrock-runtime": "^3.637.0",
107
107
  "@aws-sdk/client-s3": "^3.637.0",
@@ -2,10 +2,11 @@
2
2
  import { getAuth } from '@clerk/nextjs/server';
3
3
  import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
4
4
 
5
- import { checkAuthMethod, getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
5
+ import { checkAuthMethod } from '@/app/(backend)/middleware/auth/utils';
6
6
  import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
7
7
  import { AgentRuntime, LobeRuntimeAI } from '@/libs/agent-runtime';
8
8
  import { ChatErrorType } from '@/types/fetch';
9
+ import { getJWTPayload } from '@/utils/server/jwt';
9
10
 
10
11
  import { POST } from './route';
11
12
 
@@ -14,10 +15,13 @@ vi.mock('@clerk/nextjs/server', () => ({
14
15
  }));
15
16
 
16
17
  vi.mock('@/app/(backend)/middleware/auth/utils', () => ({
17
- getJWTPayload: vi.fn(),
18
18
  checkAuthMethod: vi.fn(),
19
19
  }));
20
20
 
21
+ vi.mock('@/utils/server/jwt', () => ({
22
+ getJWTPayload: vi.fn(),
23
+ }));
24
+
21
25
  // 定义一个变量来存储 enableAuth 的值
22
26
  let enableClerk = false;
23
27
 
@@ -1,12 +1,12 @@
1
- import { getAuth } from '@clerk/nextjs/server';
2
1
  import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
3
2
 
4
3
  import { AgentRuntimeError } from '@/libs/agent-runtime';
5
4
  import { ChatErrorType } from '@/types/fetch';
6
5
  import { createErrorResponse } from '@/utils/errorResponse';
6
+ import { getJWTPayload } from '@/utils/server/jwt';
7
7
 
8
8
  import { RequestHandler, checkAuth } from './index';
9
- import { checkAuthMethod, getJWTPayload } from './utils';
9
+ import { checkAuthMethod } from './utils';
10
10
 
11
11
  vi.mock('@clerk/nextjs/server', () => ({
12
12
  getAuth: vi.fn(),
@@ -18,6 +18,9 @@ vi.mock('@/utils/errorResponse', () => ({
18
18
 
19
19
  vi.mock('./utils', () => ({
20
20
  checkAuthMethod: vi.fn(),
21
+ }));
22
+
23
+ vi.mock('@/utils/server/jwt', () => ({
21
24
  getJWTPayload: vi.fn(),
22
25
  }));
23
26
 
@@ -6,8 +6,9 @@ import { JWTPayload, LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED, enableClerk } from
6
6
  import { AgentRuntime, AgentRuntimeError, ChatCompletionErrorPayload } from '@/libs/agent-runtime';
7
7
  import { ChatErrorType } from '@/types/fetch';
8
8
  import { createErrorResponse } from '@/utils/errorResponse';
9
+ import { getJWTPayload } from '@/utils/server/jwt';
9
10
 
10
- import { checkAuthMethod, getJWTPayload } from './utils';
11
+ import { checkAuthMethod } from './utils';
11
12
 
12
13
  type CreateRuntime = (jwtPayload: JWTPayload) => AgentRuntime;
13
14
  type RequestOptions = { createRuntime?: CreateRuntime; params: { provider: string } };
@@ -2,9 +2,8 @@ import { type AuthObject } from '@clerk/backend';
2
2
  import { beforeEach, describe, expect, it, vi } from 'vitest';
3
3
 
4
4
  import { getAppConfig } from '@/config/app';
5
- import { NON_HTTP_PREFIX } from '@/const/auth';
6
5
 
7
- import { checkAuthMethod, getJWTPayload } from './utils';
6
+ import { checkAuthMethod } from './utils';
8
7
 
9
8
  let enableClerkMock = false;
10
9
  let enableNextAuthMock = false;
@@ -27,42 +26,6 @@ vi.mock('@/config/app', () => ({
27
26
  getAppConfig: vi.fn(),
28
27
  }));
29
28
 
30
- describe('getJWTPayload', () => {
31
- it('should parse JWT payload for non-HTTPS token', async () => {
32
- const token = `${NON_HTTP_PREFIX}.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ`;
33
- const payload = await getJWTPayload(token);
34
- expect(payload).toEqual({
35
- sub: '1234567890',
36
- name: 'John Doe',
37
- iat: 1516239022,
38
- });
39
- });
40
-
41
- it('should verify and parse JWT payload for HTTPS token', async () => {
42
- const token =
43
- 'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiMDAxMzYyYzMtNDhjNS00NjM1LWJkM2ItODM3YmZmZjU4ZmMwIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY4MDIyMjUsImV4cCI6MTAwMDAwMDAwMDE3MTY4MDIwMDB9.FF0FxsE8Cajs-_hv5GD0TNUDwvekAkI9l_LL_IOPdGQ';
44
- const payload = await getJWTPayload(token);
45
- expect(payload).toEqual({
46
- accessCode: '',
47
- apiKey: 'abc',
48
- endpoint: 'abc',
49
- exp: 10000000001716802000,
50
- iat: 1716802225,
51
- userId: '001362c3-48c5-4635-bd3b-837bfff58fc0',
52
- });
53
- });
54
-
55
- it('should not verify success and parse JWT payload for dated token', async () => {
56
- const token =
57
- 'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiYWY3M2JhODktZjFhMy00YjliLWEwM2QtZGViZmZlMzE4NmQxIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY3OTk5ODAsImV4cCI6MTcxNjgwMDA4MH0.8AGFsLcwyrQG82kVUYOGFXHIwihm2n16ctyArKW9100';
58
- try {
59
- await getJWTPayload(token);
60
- } catch (e) {
61
- expect(e).toEqual(new TypeError('"exp" claim timestamp check failed'));
62
- }
63
- });
64
- });
65
-
66
29
  describe('checkAuthMethod', () => {
67
30
  beforeEach(() => {
68
31
  vi.mocked(getAppConfig).mockReturnValue({
@@ -1,42 +1,10 @@
1
1
  import { type AuthObject } from '@clerk/backend';
2
- import { importJWK, jwtVerify } from 'jose';
3
2
 
4
3
  import { getAppConfig } from '@/config/app';
5
- import {
6
- JWTPayload,
7
- JWT_SECRET_KEY,
8
- NON_HTTP_PREFIX,
9
- enableClerk,
10
- enableNextAuth,
11
- } from '@/const/auth';
4
+ import { enableClerk, enableNextAuth } from '@/const/auth';
12
5
  import { AgentRuntimeError } from '@/libs/agent-runtime';
13
6
  import { ChatErrorType } from '@/types/fetch';
14
7
 
15
- export const getJWTPayload = async (token: string): Promise<JWTPayload> => {
16
- //如果是 HTTP 协议发起的请求,直接解析 token
17
- // 这是一个非常 hack 的解决方案,未来要找更好的解决方案来处理这个问题
18
- // refs: https://github.com/lobehub/lobe-chat/pull/1238
19
- if (token.startsWith(NON_HTTP_PREFIX)) {
20
- const jwtParts = token.split('.');
21
-
22
- const payload = jwtParts[1];
23
-
24
- return JSON.parse(atob(payload));
25
- }
26
-
27
- const encoder = new TextEncoder();
28
- const secretKey = await crypto.subtle.digest('SHA-256', encoder.encode(JWT_SECRET_KEY));
29
-
30
- const jwkSecretKey = await importJWK(
31
- { k: Buffer.from(secretKey).toString('base64'), kty: 'oct' },
32
- 'HS256',
33
- );
34
-
35
- const { payload } = await jwtVerify(token, jwkSecretKey);
36
-
37
- return payload as JWTPayload;
38
- };
39
-
40
8
  interface CheckAuthParams {
41
9
  accessCode?: string;
42
10
  apiKey?: string;
@@ -1,7 +1,6 @@
1
1
  import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk';
2
2
  import { createGatewayOnEdgeRuntime } from '@lobehub/chat-plugins-gateway';
3
3
 
4
- import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
5
4
  import { getAppConfig } from '@/config/app';
6
5
  import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED, enableNextAuth } from '@/const/auth';
7
6
  import { LOBE_CHAT_TRACE_ID, TraceNameMap } from '@/const/trace';
@@ -9,6 +8,7 @@ import { AgentRuntimeError } from '@/libs/agent-runtime';
9
8
  import { TraceClient } from '@/libs/traces';
10
9
  import { ChatErrorType, ErrorType } from '@/types/fetch';
11
10
  import { createErrorResponse } from '@/utils/errorResponse';
11
+ import { getJWTPayload } from '@/utils/server/jwt';
12
12
  import { getTracePayload } from '@/utils/trace';
13
13
 
14
14
  import { parserPluginSettings } from './settings';
@@ -174,7 +174,7 @@ class AgentRuntime {
174
174
  }
175
175
 
176
176
  case ModelProvider.ZhiPu: {
177
- runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu);
177
+ runtimeModel = new LobeZhipuAI(params.zhipu);
178
178
  break;
179
179
  }
180
180
 
@@ -27,7 +27,7 @@ import { ModelProvider } from '../types/type';
27
27
  import { AgentRuntimeError } from '../utils/createError';
28
28
  import { debugStream } from '../utils/debugStream';
29
29
  import { StreamingResponse } from '../utils/response';
30
- import { GoogleGenerativeAIStream, googleGenAIResultToStream } from '../utils/streams';
30
+ import { GoogleGenerativeAIStream, convertIterableToStream } from '../utils/streams';
31
31
  import { parseDataUri } from '../utils/uriParser';
32
32
 
33
33
  enum HarmCategory {
@@ -97,7 +97,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
97
97
  tools: this.buildGoogleTools(payload.tools),
98
98
  });
99
99
 
100
- const googleStream = googleGenAIResultToStream(geminiStreamResult);
100
+ const googleStream = convertIterableToStream(geminiStreamResult.stream);
101
101
  const [prod, useForDebug] = googleStream.tee();
102
102
 
103
103
  if (process.env.DEBUG_GOOGLE_CHAT_COMPLETION === '1') {
@@ -1,6 +1,5 @@
1
1
  import Anthropic from '@anthropic-ai/sdk';
2
2
  import type { Stream } from '@anthropic-ai/sdk/streaming';
3
- import { readableFromAsyncIterable } from 'ai';
4
3
 
5
4
  import { ChatStreamCallbacks } from '../../types';
6
5
  import {
@@ -8,6 +7,7 @@ import {
8
7
  StreamProtocolToolCallChunk,
9
8
  StreamStack,
10
9
  StreamToolCallChunkData,
10
+ convertIterableToStream,
11
11
  createCallbacksTransformer,
12
12
  createSSEProtocolTransformer,
13
13
  } from './protocol';
@@ -96,12 +96,6 @@ export const transformAnthropicStream = (
96
96
  }
97
97
  };
98
98
 
99
- const chatStreamable = async function* (stream: AsyncIterable<Anthropic.MessageStreamEvent>) {
100
- for await (const response of stream) {
101
- yield response;
102
- }
103
- };
104
-
105
99
  export const AnthropicStream = (
106
100
  stream: Stream<Anthropic.MessageStreamEvent> | ReadableStream,
107
101
  callbacks?: ChatStreamCallbacks,
@@ -109,7 +103,7 @@ export const AnthropicStream = (
109
103
  const streamStack: StreamStack = { id: '' };
110
104
 
111
105
  const readableStream =
112
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
106
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
113
107
 
114
108
  return readableStream
115
109
  .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
@@ -1,5 +1,4 @@
1
1
  import { ChatCompletions, ChatCompletionsFunctionToolCall } from '@azure/openai';
2
- import { readableFromAsyncIterable } from 'ai';
3
2
  import OpenAI from 'openai';
4
3
  import type { Stream } from 'openai/streaming';
5
4
 
@@ -9,6 +8,7 @@ import {
9
8
  StreamProtocolToolCallChunk,
10
9
  StreamStack,
11
10
  StreamToolCallChunkData,
11
+ convertIterableToStream,
12
12
  createCallbacksTransformer,
13
13
  createSSEProtocolTransformer,
14
14
  } from './protocol';
@@ -69,19 +69,13 @@ const transformOpenAIStream = (chunk: ChatCompletions, stack: StreamStack): Stre
69
69
  };
70
70
  };
71
71
 
72
- const chatStreamable = async function* (stream: AsyncIterable<OpenAI.ChatCompletionChunk>) {
73
- for await (const response of stream) {
74
- yield response;
75
- }
76
- };
77
-
78
72
  export const AzureOpenAIStream = (
79
73
  stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
80
74
  callbacks?: ChatStreamCallbacks,
81
75
  ) => {
82
76
  const stack: StreamStack = { id: '' };
83
77
  const readableStream =
84
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
78
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
85
79
 
86
80
  return readableStream
87
81
  .pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, stack))
@@ -1,8 +1,4 @@
1
- import {
2
- EnhancedGenerateContentResponse,
3
- GenerateContentStreamResult,
4
- } from '@google/generative-ai';
5
- import { readableFromAsyncIterable } from 'ai';
1
+ import { EnhancedGenerateContentResponse } from '@google/generative-ai';
6
2
 
7
3
  import { nanoid } from '@/utils/uuid';
8
4
 
@@ -11,7 +7,6 @@ import {
11
7
  StreamProtocolChunk,
12
8
  StreamStack,
13
9
  StreamToolCallChunkData,
14
- chatStreamable,
15
10
  createCallbacksTransformer,
16
11
  createSSEProtocolTransformer,
17
12
  generateToolCallId,
@@ -50,12 +45,6 @@ const transformGoogleGenerativeAIStream = (
50
45
  };
51
46
  };
52
47
 
53
- // only use for debug
54
- export const googleGenAIResultToStream = (stream: GenerateContentStreamResult) => {
55
- // make the response to the streamable format
56
- return readableFromAsyncIterable(chatStreamable(stream.stream));
57
- };
58
-
59
48
  export const GoogleGenerativeAIStream = (
60
49
  rawStream: ReadableStream<EnhancedGenerateContentResponse>,
61
50
  callbacks?: ChatStreamCallbacks,
@@ -1,4 +1,3 @@
1
- import { readableFromAsyncIterable } from 'ai';
2
1
  import { ChatResponse } from 'ollama/browser';
3
2
 
4
3
  import { ChatStreamCallbacks } from '@/libs/agent-runtime';
@@ -7,6 +6,7 @@ import { nanoid } from '@/utils/uuid';
7
6
  import {
8
7
  StreamProtocolChunk,
9
8
  StreamStack,
9
+ convertIterableToStream,
10
10
  createCallbacksTransformer,
11
11
  createSSEProtocolTransformer,
12
12
  } from './protocol';
@@ -20,19 +20,13 @@ const transformOllamaStream = (chunk: ChatResponse, stack: StreamStack): StreamP
20
20
  return { data: chunk.message.content, id: stack.id, type: 'text' };
21
21
  };
22
22
 
23
- const chatStreamable = async function* (stream: AsyncIterable<ChatResponse>) {
24
- for await (const response of stream) {
25
- yield response;
26
- }
27
- };
28
-
29
23
  export const OllamaStream = (
30
24
  res: AsyncIterable<ChatResponse>,
31
25
  cb?: ChatStreamCallbacks,
32
26
  ): ReadableStream<string> => {
33
27
  const streamStack: StreamStack = { id: 'chat_' + nanoid() };
34
28
 
35
- return readableFromAsyncIterable(chatStreamable(res))
29
+ return convertIterableToStream(res)
36
30
  .pipeThrough(createSSEProtocolTransformer(transformOllamaStream, streamStack))
37
31
  .pipeThrough(createCallbacksTransformer(cb));
38
32
  };
@@ -1,4 +1,3 @@
1
- import { readableFromAsyncIterable } from 'ai';
2
1
  import OpenAI from 'openai';
3
2
  import type { Stream } from 'openai/streaming';
4
3
 
@@ -10,6 +9,7 @@ import {
10
9
  StreamProtocolToolCallChunk,
11
10
  StreamStack,
12
11
  StreamToolCallChunkData,
12
+ convertIterableToStream,
13
13
  createCallbacksTransformer,
14
14
  createSSEProtocolTransformer,
15
15
  generateToolCallId,
@@ -105,12 +105,6 @@ export const transformOpenAIStream = (
105
105
  }
106
106
  };
107
107
 
108
- const chatStreamable = async function* (stream: AsyncIterable<OpenAI.ChatCompletionChunk>) {
109
- for await (const response of stream) {
110
- yield response;
111
- }
112
- };
113
-
114
108
  export const OpenAIStream = (
115
109
  stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
116
110
  callbacks?: ChatStreamCallbacks,
@@ -118,7 +112,7 @@ export const OpenAIStream = (
118
112
  const streamStack: StreamStack = { id: '' };
119
113
 
120
114
  const readableStream =
121
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
115
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
122
116
 
123
117
  return readableStream
124
118
  .pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, streamStack))
@@ -1,3 +1,5 @@
1
+ import { readableFromAsyncIterable } from 'ai';
2
+
1
3
  import { ChatStreamCallbacks } from '@/libs/agent-runtime';
2
4
 
3
5
  export interface StreamStack {
@@ -42,6 +44,11 @@ export const chatStreamable = async function* <T>(stream: AsyncIterable<T>) {
42
44
  }
43
45
  };
44
46
 
47
+ // make the response to the streamable format
48
+ export const convertIterableToStream = <T>(stream: AsyncIterable<T>) => {
49
+ return readableFromAsyncIterable(chatStreamable(stream));
50
+ };
51
+
45
52
  export const createSSEProtocolTransformer = (
46
53
  transformer: (chunk: any, stack: StreamStack) => StreamProtocolChunk,
47
54
  streamStack?: StreamStack,
@@ -1,4 +1,3 @@
1
- import { readableFromAsyncIterable } from 'ai';
2
1
  import { ChatCompletionContentPartText } from 'ai/prompts';
3
2
  import OpenAI from 'openai';
4
3
  import { ChatCompletionContentPart } from 'openai/resources/index.mjs';
@@ -9,7 +8,7 @@ import {
9
8
  StreamProtocolChunk,
10
9
  StreamProtocolToolCallChunk,
11
10
  StreamToolCallChunkData,
12
- chatStreamable,
11
+ convertIterableToStream,
13
12
  createCallbacksTransformer,
14
13
  createSSEProtocolTransformer,
15
14
  generateToolCallId,
@@ -86,7 +85,7 @@ export const QwenAIStream = (
86
85
  callbacks?: ChatStreamCallbacks,
87
86
  ) => {
88
87
  const readableStream =
89
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
88
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
90
89
 
91
90
  return readableStream
92
91
  .pipeThrough(createSSEProtocolTransformer(transformQwenStream))
@@ -2,8 +2,9 @@ import { describe, expect, it, vi } from 'vitest';
2
2
 
3
3
  import * as uuidModule from '@/utils/uuid';
4
4
 
5
+ import { convertIterableToStream } from '../../utils/streams/protocol';
5
6
  import { ChatResp } from '../../wenxin/type';
6
- import { WenxinResultToStream, WenxinStream } from './wenxin';
7
+ import { WenxinStream } from './wenxin';
7
8
 
8
9
  const dataStream = [
9
10
  {
@@ -95,7 +96,7 @@ describe('WenxinStream', () => {
95
96
  },
96
97
  };
97
98
 
98
- const stream = WenxinResultToStream(mockWenxinStream);
99
+ const stream = convertIterableToStream(mockWenxinStream);
99
100
 
100
101
  const onStartMock = vi.fn();
101
102
  const onTextMock = vi.fn();
@@ -142,7 +143,10 @@ describe('WenxinStream', () => {
142
143
 
143
144
  expect(onStartMock).toHaveBeenCalledTimes(1);
144
145
  expect(onTextMock).toHaveBeenNthCalledWith(1, '"当然可以,"');
145
- expect(onTextMock).toHaveBeenNthCalledWith(2, '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"');
146
+ expect(onTextMock).toHaveBeenNthCalledWith(
147
+ 2,
148
+ '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"',
149
+ );
146
150
  expect(onTokenMock).toHaveBeenCalledTimes(6);
147
151
  expect(onCompletionMock).toHaveBeenCalledTimes(1);
148
152
  });
@@ -1,5 +1,3 @@
1
- import { readableFromAsyncIterable } from 'ai';
2
-
3
1
  import { ChatStreamCallbacks } from '@/libs/agent-runtime';
4
2
  import { nanoid } from '@/utils/uuid';
5
3
 
@@ -7,7 +5,6 @@ import { ChatResp } from '../../wenxin/type';
7
5
  import {
8
6
  StreamProtocolChunk,
9
7
  StreamStack,
10
- chatStreamable,
11
8
  createCallbacksTransformer,
12
9
  createSSEProtocolTransformer,
13
10
  } from './protocol';
@@ -29,11 +26,6 @@ const transformERNIEBotStream = (chunk: ChatResp): StreamProtocolChunk => {
29
26
  };
30
27
  };
31
28
 
32
- export const WenxinResultToStream = (stream: AsyncIterable<ChatResp>) => {
33
- // make the response to the streamable format
34
- return readableFromAsyncIterable(chatStreamable(stream));
35
- };
36
-
37
29
  export const WenxinStream = (
38
30
  rawStream: ReadableStream<ChatResp>,
39
31
  callbacks?: ChatStreamCallbacks,
@@ -10,7 +10,8 @@ import { ChatCompetitionOptions, ChatStreamPayload } from '../types';
10
10
  import { AgentRuntimeError } from '../utils/createError';
11
11
  import { debugStream } from '../utils/debugStream';
12
12
  import { StreamingResponse } from '../utils/response';
13
- import { WenxinResultToStream, WenxinStream } from '../utils/streams/wenxin';
13
+ import { convertIterableToStream } from '../utils/streams';
14
+ import { WenxinStream } from '../utils/streams/wenxin';
14
15
  import { ChatResp } from './type';
15
16
 
16
17
  interface ChatErrorCode {
@@ -46,7 +47,7 @@ export class LobeWenxinAI implements LobeRuntimeAI {
46
47
  payload.model,
47
48
  );
48
49
 
49
- const wenxinStream = WenxinResultToStream(result as AsyncIterable<ChatResp>);
50
+ const wenxinStream = convertIterableToStream(result as AsyncIterable<ChatResp>);
50
51
 
51
52
  const [prod, useForDebug] = wenxinStream.tee();
52
53
 
@@ -2,7 +2,7 @@
2
2
  import { OpenAI } from 'openai';
3
3
  import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
4
4
 
5
- import { ChatStreamCallbacks, LobeOpenAI } from '@/libs/agent-runtime';
5
+ import { ChatStreamCallbacks, LobeOpenAI, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
6
6
  import * as debugStreamModule from '@/libs/agent-runtime/utils/debugStream';
7
7
 
8
8
  import * as authTokenModule from './authToken';
@@ -24,28 +24,11 @@ describe('LobeZhipuAI', () => {
24
24
  vi.restoreAllMocks();
25
25
  });
26
26
 
27
- describe('fromAPIKey', () => {
28
- it('should correctly initialize with an API key', async () => {
29
- const lobeZhipuAI = await LobeZhipuAI.fromAPIKey({ apiKey: 'test_api_key' });
30
- expect(lobeZhipuAI).toBeInstanceOf(LobeZhipuAI);
31
- expect(lobeZhipuAI.baseURL).toEqual('https://open.bigmodel.cn/api/paas/v4');
32
- });
33
-
34
- it('should throw an error if API key is invalid', async () => {
35
- vi.spyOn(authTokenModule, 'generateApiToken').mockRejectedValue(new Error('Invalid API Key'));
36
- try {
37
- await LobeZhipuAI.fromAPIKey({ apiKey: 'asd' });
38
- } catch (e) {
39
- expect(e).toEqual({ errorType: invalidErrorType });
40
- }
41
- });
42
- });
43
-
44
27
  describe('chat', () => {
45
- let instance: LobeZhipuAI;
28
+ let instance: LobeOpenAICompatibleRuntime;
46
29
 
47
30
  beforeEach(async () => {
48
- instance = await LobeZhipuAI.fromAPIKey({
31
+ instance = new LobeZhipuAI({
49
32
  apiKey: 'test_api_key',
50
33
  });
51
34
 
@@ -131,9 +114,9 @@ describe('LobeZhipuAI', () => {
131
114
  const calledWithParams = spyOn.mock.calls[0][0];
132
115
 
133
116
  expect(calledWithParams.messages[1].content).toEqual([{ type: 'text', text: 'Hello again' }]);
134
- expect(calledWithParams.temperature).toBeUndefined(); // temperature 0 should be undefined
117
+ expect(calledWithParams.temperature).toBe(0); // temperature 0 should be undefined
135
118
  expect((calledWithParams as any).do_sample).toBeTruthy(); // temperature 0 should be undefined
136
- expect(calledWithParams.top_p).toEqual(0.99); // top_p should be transformed correctly
119
+ expect(calledWithParams.top_p).toEqual(1); // top_p should be transformed correctly
137
120
  });
138
121
 
139
122
  describe('Error', () => {
@@ -175,7 +158,7 @@ describe('LobeZhipuAI', () => {
175
158
 
176
159
  it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
177
160
  try {
178
- await LobeZhipuAI.fromAPIKey({ apiKey: '' });
161
+ new LobeZhipuAI({ apiKey: '' });
179
162
  } catch (e) {
180
163
  expect(e).toEqual({ errorType: invalidErrorType });
181
164
  }
@@ -221,7 +204,7 @@ describe('LobeZhipuAI', () => {
221
204
  };
222
205
  const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
223
206
 
224
- instance = await LobeZhipuAI.fromAPIKey({
207
+ instance = new LobeZhipuAI({
225
208
  apiKey: 'test',
226
209
 
227
210
  baseURL: 'https://abc.com/v2',
@@ -1,99 +1,21 @@
1
- import OpenAI, { ClientOptions } from 'openai';
2
-
3
- import { LobeRuntimeAI } from '../BaseAI';
4
- import { AgentRuntimeErrorType } from '../error';
5
- import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
6
- import { AgentRuntimeError } from '../utils/createError';
7
- import { debugStream } from '../utils/debugStream';
8
- import { desensitizeUrl } from '../utils/desensitizeUrl';
9
- import { handleOpenAIError } from '../utils/handleOpenAIError';
10
- import { convertOpenAIMessages } from '../utils/openaiHelpers';
11
- import { StreamingResponse } from '../utils/response';
12
- import { OpenAIStream } from '../utils/streams';
13
- import { generateApiToken } from './authToken';
14
-
15
- const DEFAULT_BASE_URL = 'https://open.bigmodel.cn/api/paas/v4';
16
-
17
- export class LobeZhipuAI implements LobeRuntimeAI {
18
- private client: OpenAI;
19
-
20
- baseURL: string;
21
-
22
- constructor(oai: OpenAI) {
23
- this.client = oai;
24
- this.baseURL = this.client.baseURL;
25
- }
26
-
27
- static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions = {}) {
28
- const invalidZhipuAPIKey = AgentRuntimeError.createError(
29
- AgentRuntimeErrorType.InvalidProviderAPIKey,
30
- );
31
-
32
- if (!apiKey) throw invalidZhipuAPIKey;
33
-
34
- let token: string;
35
-
36
- try {
37
- token = await generateApiToken(apiKey);
38
- } catch {
39
- throw invalidZhipuAPIKey;
40
- }
41
-
42
- const header = { Authorization: `Bearer ${token}` };
43
- const llm = new OpenAI({ apiKey, baseURL, defaultHeaders: header, ...res });
44
-
45
- return new LobeZhipuAI(llm);
46
- }
47
-
48
- async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
49
- try {
50
- const params = await this.buildCompletionsParams(payload);
51
-
52
- const response = await this.client.chat.completions.create(
53
- params as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
54
- );
55
-
56
- const [prod, debug] = response.tee();
57
-
58
- if (process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1') {
59
- debugStream(debug.toReadableStream()).catch(console.error);
60
- }
61
-
62
- return StreamingResponse(OpenAIStream(prod, options?.callback), {
63
- headers: options?.headers,
64
- });
65
- } catch (error) {
66
- const { errorResult, RuntimeError } = handleOpenAIError(error);
67
-
68
- const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError;
69
- let desensitizedEndpoint = this.baseURL;
70
-
71
- if (this.baseURL !== DEFAULT_BASE_URL) {
72
- desensitizedEndpoint = desensitizeUrl(this.baseURL);
73
- }
74
- throw AgentRuntimeError.chat({
75
- endpoint: desensitizedEndpoint,
76
- error: errorResult,
77
- errorType,
78
- provider: ModelProvider.ZhiPu,
79
- });
80
- }
81
- }
82
-
83
- private async buildCompletionsParams(payload: ChatStreamPayload) {
84
- const { messages, temperature, top_p, ...params } = payload;
85
-
86
- return {
87
- messages: await convertOpenAIMessages(messages as any),
88
- ...params,
89
- do_sample: temperature === 0,
90
- stream: true,
91
- // 当前的模型侧不支持 top_p=1 和 temperature 为 0
92
- // refs: https://zhipu-ai.feishu.cn/wiki/TUo0w2LT7iswnckmfSEcqTD0ncd
93
- temperature: temperature === 0 ? undefined : temperature,
94
- top_p: top_p === 1 ? 0.99 : top_p,
95
- };
96
- }
97
- }
98
-
99
- export default LobeZhipuAI;
1
+ import OpenAI from 'openai';
2
+
3
+ import { ChatStreamPayload, ModelProvider } from '../types';
4
+ import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
5
+
6
+ export const LobeZhipuAI = LobeOpenAICompatibleFactory({
7
+ baseURL: 'https://open.bigmodel.cn/api/paas/v4',
8
+ chatCompletion: {
9
+ handlePayload: ({ temperature, ...payload }: ChatStreamPayload) =>
10
+ ({
11
+ ...payload,
12
+ do_sample: temperature === 0,
13
+ stream: true,
14
+ temperature,
15
+ }) as OpenAI.ChatCompletionCreateParamsStreaming,
16
+ },
17
+ debug: {
18
+ chatCompletion: () => process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1',
19
+ },
20
+ provider: ModelProvider.ZhiPu,
21
+ });
@@ -2,10 +2,10 @@
2
2
  import { TRPCError } from '@trpc/server';
3
3
  import { beforeEach, describe, expect, it, vi } from 'vitest';
4
4
 
5
- import * as utils from '@/app/(backend)/middleware/auth/utils';
6
5
  import { createCallerFactory } from '@/libs/trpc';
7
6
  import { trpc } from '@/libs/trpc/init';
8
7
  import { AuthContext, createContextInner } from '@/server/context';
8
+ import * as utils from '@/utils/server/jwt';
9
9
 
10
10
  import { jwtPayloadChecker } from './jwtPayload';
11
11
 
@@ -1,7 +1,7 @@
1
1
  import { TRPCError } from '@trpc/server';
2
2
 
3
- import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
4
3
  import { trpc } from '@/libs/trpc/init';
4
+ import { getJWTPayload } from '@/utils/server/jwt';
5
5
 
6
6
  export const jwtPayloadChecker = trpc.middleware(async (opts) => {
7
7
  const { ctx } = opts;
@@ -1,7 +1,7 @@
1
1
  import { TRPCError } from '@trpc/server';
2
2
 
3
- import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
4
3
  import { trpc } from '@/libs/trpc/init';
4
+ import { getJWTPayload } from '@/utils/server/jwt';
5
5
 
6
6
  export const keyVaults = trpc.middleware(async (opts) => {
7
7
  const { ctx } = opts;
@@ -0,0 +1,62 @@
1
+ import { describe, expect, it, vi } from 'vitest';
2
+
3
+ import { NON_HTTP_PREFIX } from '@/const/auth';
4
+
5
+ import { getJWTPayload } from './jwt';
6
+
7
+ let enableClerkMock = false;
8
+ let enableNextAuthMock = false;
9
+
10
+ vi.mock('@/const/auth', async (importOriginal) => {
11
+ const data = await importOriginal();
12
+
13
+ return {
14
+ ...(data as any),
15
+ get enableClerk() {
16
+ return enableClerkMock;
17
+ },
18
+ get enableNextAuth() {
19
+ return enableNextAuthMock;
20
+ },
21
+ };
22
+ });
23
+
24
+ vi.mock('@/config/app', () => ({
25
+ getAppConfig: vi.fn(),
26
+ }));
27
+
28
+ describe('getJWTPayload', () => {
29
+ it('should parse JWT payload for non-HTTPS token', async () => {
30
+ const token = `${NON_HTTP_PREFIX}.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ`;
31
+ const payload = await getJWTPayload(token);
32
+ expect(payload).toEqual({
33
+ sub: '1234567890',
34
+ name: 'John Doe',
35
+ iat: 1516239022,
36
+ });
37
+ });
38
+
39
+ it('should verify and parse JWT payload for HTTPS token', async () => {
40
+ const token =
41
+ 'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiMDAxMzYyYzMtNDhjNS00NjM1LWJkM2ItODM3YmZmZjU4ZmMwIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY4MDIyMjUsImV4cCI6MTAwMDAwMDAwMDE3MTY4MDIwMDB9.FF0FxsE8Cajs-_hv5GD0TNUDwvekAkI9l_LL_IOPdGQ';
42
+ const payload = await getJWTPayload(token);
43
+ expect(payload).toEqual({
44
+ accessCode: '',
45
+ apiKey: 'abc',
46
+ endpoint: 'abc',
47
+ exp: 10000000001716802000,
48
+ iat: 1716802225,
49
+ userId: '001362c3-48c5-4635-bd3b-837bfff58fc0',
50
+ });
51
+ });
52
+
53
+ it('should not verify success and parse JWT payload for dated token', async () => {
54
+ const token =
55
+ 'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiYWY3M2JhODktZjFhMy00YjliLWEwM2QtZGViZmZlMzE4NmQxIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY3OTk5ODAsImV4cCI6MTcxNjgwMDA4MH0.8AGFsLcwyrQG82kVUYOGFXHIwihm2n16ctyArKW9100';
56
+ try {
57
+ await getJWTPayload(token);
58
+ } catch (e) {
59
+ expect(e).toEqual(new TypeError('"exp" claim timestamp check failed'));
60
+ }
61
+ });
62
+ });
@@ -0,0 +1,32 @@
1
+ import { importJWK, jwtVerify } from 'jose';
2
+
3
+ import {
4
+ JWTPayload,
5
+ JWT_SECRET_KEY,
6
+ NON_HTTP_PREFIX,
7
+ } from '@/const/auth';
8
+
9
+ export const getJWTPayload = async (token: string): Promise<JWTPayload> => {
10
+ //如果是 HTTP 协议发起的请求,直接解析 token
11
+ // 这是一个非常 hack 的解决方案,未来要找更好的解决方案来处理这个问题
12
+ // refs: https://github.com/lobehub/lobe-chat/pull/1238
13
+ if (token.startsWith(NON_HTTP_PREFIX)) {
14
+ const jwtParts = token.split('.');
15
+
16
+ const payload = jwtParts[1];
17
+
18
+ return JSON.parse(atob(payload));
19
+ }
20
+
21
+ const encoder = new TextEncoder();
22
+ const secretKey = await crypto.subtle.digest('SHA-256', encoder.encode(JWT_SECRET_KEY));
23
+
24
+ const jwkSecretKey = await importJWK(
25
+ { k: Buffer.from(secretKey).toString('base64'), kty: 'oct' },
26
+ 'HS256',
27
+ );
28
+
29
+ const { payload } = await jwtVerify(token, jwkSecretKey);
30
+
31
+ return payload as JWTPayload;
32
+ };