@lobehub/chat 0.159.12 → 0.160.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/CHANGELOG.md +25 -0
  2. package/README.md +1 -1
  3. package/README.zh-CN.md +1 -1
  4. package/package.json +4 -2
  5. package/src/app/trpc/{[trpc] → edge/[trpc]}/route.ts +3 -3
  6. package/src/config/__tests__/server.test.ts +0 -11
  7. package/src/config/file.ts +34 -0
  8. package/src/config/server/app.ts +0 -8
  9. package/src/config/server/provider.ts +3 -3
  10. package/src/database/client/models/file.ts +2 -1
  11. package/src/database/client/schemas/files.ts +2 -2
  12. package/src/libs/agent-runtime/google/index.test.ts +20 -1
  13. package/src/libs/agent-runtime/google/index.ts +22 -9
  14. package/src/libs/agent-runtime/utils/uriParser.test.ts +29 -0
  15. package/src/libs/agent-runtime/utils/uriParser.ts +17 -9
  16. package/src/libs/trpc/client.ts +5 -3
  17. package/src/libs/trpc/index.ts +10 -34
  18. package/src/libs/trpc/init.ts +26 -0
  19. package/src/libs/trpc/middleware/password.test.ts +87 -0
  20. package/src/libs/trpc/middleware/password.ts +26 -0
  21. package/src/libs/trpc/middleware/userAuth.test.ts +44 -0
  22. package/src/libs/trpc/middleware/userAuth.ts +18 -0
  23. package/src/server/context.ts +28 -3
  24. package/src/server/files/s3.ts +58 -0
  25. package/src/server/globalConfig/index.ts +2 -0
  26. package/src/server/mock.ts +2 -2
  27. package/src/server/routers/{config → edge/config}/index.test.ts +1 -0
  28. package/src/server/routers/edge/upload.ts +16 -0
  29. package/src/server/routers/index.ts +5 -3
  30. package/src/services/__tests__/global.test.ts +4 -5
  31. package/src/services/__tests__/sync.test.ts +56 -0
  32. package/src/services/__tests__/upload.test.ts +72 -0
  33. package/src/services/_url.ts +2 -0
  34. package/src/services/file/client.test.ts +102 -34
  35. package/src/services/file/client.ts +24 -49
  36. package/src/services/file/type.ts +1 -2
  37. package/src/services/global.ts +3 -18
  38. package/src/services/sync.ts +19 -0
  39. package/src/services/upload.ts +99 -0
  40. package/src/store/chat/slices/builtinTool/action.test.ts +4 -2
  41. package/src/store/chat/slices/builtinTool/action.ts +6 -3
  42. package/src/store/file/slices/images/action.test.ts +10 -17
  43. package/src/store/file/slices/images/action.ts +4 -1
  44. package/src/store/file/slices/tts/action.test.ts +8 -14
  45. package/src/store/file/slices/tts/action.ts +4 -1
  46. package/src/store/serverConfig/selectors.ts +1 -0
  47. package/src/store/serverConfig/store.ts +10 -0
  48. package/src/store/user/slices/common/action.ts +26 -14
  49. package/src/store/user/slices/sync/action.test.ts +6 -6
  50. package/src/store/user/slices/sync/action.ts +3 -3
  51. package/src/types/serverConfig.ts +1 -0
  52. package/src/app/api/files/image/imgur.ts +0 -72
  53. package/src/app/api/files/image/route.ts +0 -42
  54. /package/src/server/routers/{config → edge/config}/__snapshots__/index.test.ts.snap +0 -0
  55. /package/src/server/routers/{config → edge/config}/index.ts +0 -0
package/CHANGELOG.md CHANGED
@@ -2,6 +2,31 @@
2
2
 
3
3
  # Changelog
4
4
 
5
+ ## [Version 0.160.0](https://github.com/lobehub/lobe-chat/compare/v0.159.12...v0.160.0)
6
+
7
+ <sup>Released on **2024-05-18**</sup>
8
+
9
+ #### ✨ Features
10
+
11
+ - **misc**: Bump version and add enable ollama env.
12
+
13
+ <br/>
14
+
15
+ <details>
16
+ <summary><kbd>Improvements and Fixes</kbd></summary>
17
+
18
+ #### What's improved
19
+
20
+ - **misc**: Bump version and add enable ollama env, closes [#2554](https://github.com/lobehub/lobe-chat/issues/2554) ([f5ce7c9](https://github.com/lobehub/lobe-chat/commit/f5ce7c9))
21
+
22
+ </details>
23
+
24
+ <div align="right">
25
+
26
+ [![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top)
27
+
28
+ </div>
29
+
5
30
  ### [Version 0.159.12](https://github.com/lobehub/lobe-chat/compare/v0.159.11...v0.159.12)
6
31
 
7
32
  <sup>Released on **2024-05-15**</sup>
package/README.md CHANGED
@@ -230,7 +230,7 @@ In addition, these plugins are not limited to news aggregation, but can also ext
230
230
  | [Social Search](https://chat-preview.lobehub.com/settings/agent)<br/><sup>By **say-apps** on **2024-05-02**</sup> | The Social Search provides access to tweets, users, followers, images, media and more.<br/>`social` `twitter` `x` `search` |
231
231
  | [Search Google via Serper](https://chat-preview.lobehub.com/settings/agent)<br/><sup>By **Barry** on **2024-04-30**</sup> | Google search engine via Serper.dev free API (2500x🆓/month)<br/>`web` `search` |
232
232
 
233
- > 📊 Total plugins: [<kbd>**58**</kbd>](https://github.com/lobehub/lobe-chat-plugins)
233
+ > 📊 Total plugins: [<kbd>**56**</kbd>](https://github.com/lobehub/lobe-chat-plugins)
234
234
 
235
235
  <!-- PLUGIN LIST -->
236
236
 
package/README.zh-CN.md CHANGED
@@ -222,7 +222,7 @@ LobeChat 的插件生态系统是其核心功能的重要扩展,它极大地
222
222
  | [社交搜索](https://chat-preview.lobehub.com/settings/agent)<br/><sup>By **say-apps** on **2024-05-02**</sup> | 社交搜索提供访问推文、用户、关注者、图片、媒体等功能。<br/>`社交` `推特` `x` `搜索` |
223
223
  | [通过 Serper 搜索 Google](https://chat-preview.lobehub.com/settings/agent)<br/><sup>By **Barry** on **2024-04-30**</sup> | 通过 Serper.dev 免费 API 进行 Google 搜索引擎(每月 2500 次🆓)<br/>`网络` `搜索` |
224
224
 
225
- > 📊 Total plugins: [<kbd>**58**</kbd>](https://github.com/lobehub/lobe-chat-plugins)
225
+ > 📊 Total plugins: [<kbd>**56**</kbd>](https://github.com/lobehub/lobe-chat-plugins)
226
226
 
227
227
  <!-- PLUGIN LIST -->
228
228
 
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@lobehub/chat",
3
- "version": "0.159.12",
3
+ "version": "0.160.0",
4
4
  "description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.",
5
5
  "keywords": [
6
6
  "framework",
@@ -85,6 +85,8 @@
85
85
  "@anthropic-ai/sdk": "^0.20.9",
86
86
  "@auth/core": "0.28.0",
87
87
  "@aws-sdk/client-bedrock-runtime": "^3.574.0",
88
+ "@aws-sdk/client-s3": "^3.577.0",
89
+ "@aws-sdk/s3-request-presigner": "^3.577.0",
88
90
  "@azure/openai": "1.0.0-beta.12",
89
91
  "@cfworker/json-schema": "^1.12.8",
90
92
  "@clerk/localizations": "2.0.0",
@@ -108,7 +110,7 @@
108
110
  "@vercel/speed-insights": "^1.0.10",
109
111
  "ahooks": "^3.7.11",
110
112
  "ai": "3.0.19",
111
- "antd": "5.17.0",
113
+ "antd": "^5.17.2",
112
114
  "antd-style": "^3.6.2",
113
115
  "brotli-wasm": "^3.0.0",
114
116
  "chroma-js": "^2.4.2",
package/src/app/trpc/{[trpc] → edge/[trpc]}/route.ts CHANGED
@@ -3,7 +3,7 @@ import type { NextRequest } from 'next/server';
3
3
 
4
4
  import { pino } from '@/libs/logger';
5
5
  import { createContext } from '@/server/context';
6
- import { appRouter } from '@/server/routers';
6
+ import { edgeRouter } from '@/server/routers';
7
7
 
8
8
  export const runtime = 'edge';
9
9
 
@@ -14,7 +14,7 @@ const handler = (req: NextRequest) =>
14
14
  */
15
15
  createContext: () => createContext(req),
16
16
 
17
- endpoint: '/trpc',
17
+ endpoint: '/trpc/edge',
18
18
 
19
19
  onError: ({ error, path }) => {
20
20
  pino.info(`Error in tRPC handler (edge) on path: ${path}`);
@@ -22,7 +22,7 @@ const handler = (req: NextRequest) =>
22
22
  },
23
23
 
24
24
  req,
25
- router: appRouter,
25
+ router: edgeRouter,
26
26
  });
27
27
 
28
28
  export { handler as GET, handler as POST };
package/src/config/__tests__/server.test.ts CHANGED
@@ -33,17 +33,6 @@ describe('getServerConfig', () => {
33
33
  expect(config.OPENAI_FUNCTION_REGIONS).toStrictEqual(['iad1', 'sfo1']);
34
34
  });
35
35
 
36
- it('returns default IMGUR_CLIENT_ID when no environment variable is set', () => {
37
- const config = getServerConfig();
38
- expect(config.IMGUR_CLIENT_ID).toBe('e415f320d6e24f9');
39
- });
40
-
41
- it('returns custom IMGUR_CLIENT_ID when environment variable is set', () => {
42
- process.env.IMGUR_CLIENT_ID = 'custom-client-id';
43
- const config = getServerConfig();
44
- expect(config.IMGUR_CLIENT_ID).toBe('custom-client-id');
45
- });
46
-
47
36
  describe('index url', () => {
48
37
  it('should return default URLs when no environment variables are set', () => {
49
38
  const config = getServerConfig();
package/src/config/file.ts ADDED
@@ -0,0 +1,34 @@
1
+ import { createEnv } from '@t3-oss/env-nextjs';
2
+ import { z } from 'zod';
3
+
4
+ const DEFAULT_S3_FILE_PATH = 'files';
5
+
6
+ export const getFileConfig = () => {
7
+ return createEnv({
8
+ client: {
9
+ NEXT_PUBLIC_S3_DOMAIN: z.string().optional(),
10
+ NEXT_PUBLIC_S3_FILE_PATH: z.string().optional(),
11
+ },
12
+ runtimeEnv: {
13
+ NEXT_PUBLIC_S3_DOMAIN: process.env.NEXT_PUBLIC_S3_DOMAIN,
14
+ NEXT_PUBLIC_S3_FILE_PATH: process.env.NEXT_PUBLIC_S3_FILE_PATH || DEFAULT_S3_FILE_PATH,
15
+
16
+ S3_ACCESS_KEY_ID: process.env.S3_ACCESS_KEY_ID,
17
+ S3_BUCKET: process.env.S3_BUCKET,
18
+ S3_ENDPOINT: process.env.S3_ENDPOINT,
19
+ S3_REGION: process.env.S3_REGION,
20
+ S3_SECRET_ACCESS_KEY: process.env.S3_SECRET_ACCESS_KEY,
21
+ },
22
+ server: {
23
+ // S3
24
+ S3_ACCESS_KEY_ID: z.string().optional(),
25
+ S3_BUCKET: z.string().optional(),
26
+ S3_ENDPOINT: z.string().optional(),
27
+
28
+ S3_REGION: z.string().optional(),
29
+ S3_SECRET_ACCESS_KEY: z.string().optional(),
30
+ },
31
+ });
32
+ };
33
+
34
+ export const fileEnv = getFileConfig();
package/src/config/server/app.ts CHANGED
@@ -6,8 +6,6 @@ declare global {
6
6
  interface ProcessEnv {
7
7
  ACCESS_CODE?: string;
8
8
 
9
- IMGUR_CLIENT_ID?: string;
10
-
11
9
  SITE_URL?: string;
12
10
 
13
11
  AGENTS_INDEX_URL?: string;
@@ -25,10 +23,6 @@ declare global {
25
23
  }
26
24
  }
27
25
 
28
- // we apply a free imgur app to get a client id
29
- // refs: https://apidocs.imgur.com/
30
- const DEFAULT_IMAGUR_CLIENT_ID = 'e415f320d6e24f9';
31
-
32
26
  export const getAppConfig = () => {
33
27
  if (typeof process === 'undefined') {
34
28
  throw new Error('[Server Config] you are importing a server-only module outside of server');
@@ -45,8 +39,6 @@ export const getAppConfig = () => {
45
39
 
46
40
  SITE_URL: process.env.SITE_URL,
47
41
 
48
- IMGUR_CLIENT_ID: process.env.IMGUR_CLIENT_ID || DEFAULT_IMAGUR_CLIENT_ID,
49
-
50
42
  AGENTS_INDEX_URL: !!process.env.AGENTS_INDEX_URL
51
43
  ? process.env.AGENTS_INDEX_URL
52
44
  : 'https://chat-agents.lobehub.com',
package/src/config/server/provider.ts CHANGED
@@ -22,7 +22,7 @@ declare global {
22
22
  // DeepSeek Provider
23
23
  ENABLED_DEEPSEEK?: string;
24
24
  DEEPSEEK_API_KEY?: string;
25
-
25
+
26
26
  // ZhiPu Provider
27
27
  ENABLED_ZHIPU?: string;
28
28
  ZHIPU_API_KEY?: string;
@@ -113,7 +113,7 @@ export const getProviderConfig = () => {
113
113
  const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID || '';
114
114
 
115
115
  const DEEPSEEK_API_KEY = process.env.DEEPSEEK_API_KEY || '';
116
-
116
+
117
117
  const GOOGLE_API_KEY = process.env.GOOGLE_API_KEY || '';
118
118
 
119
119
  const MOONSHOT_API_KEY = process.env.MOONSHOT_API_KEY || '';
@@ -221,7 +221,7 @@ export const getProviderConfig = () => {
221
221
  AWS_ACCESS_KEY_ID: AWS_ACCESS_KEY_ID,
222
222
  AWS_SECRET_ACCESS_KEY: process.env.AWS_SECRET_ACCESS_KEY || '',
223
223
 
224
- ENABLE_OLLAMA: Boolean(process.env.ENABLE_OLLAMA),
224
+ ENABLE_OLLAMA: process.env.ENABLE_OLLAMA === '0',
225
225
  OLLAMA_PROXY_URL: process.env.OLLAMA_PROXY_URL || '',
226
226
  OLLAMA_MODEL_LIST: process.env.OLLAMA_MODEL_LIST || process.env.OLLAMA_CUSTOM_MODELS,
227
227
  };
package/src/database/client/models/file.ts CHANGED
@@ -1,3 +1,4 @@
1
+ import { DBModel } from '@/database/client/core/types/db';
1
2
  import { DB_File, DB_FileSchema } from '@/database/client/schemas/files';
2
3
  import { nanoid } from '@/utils/uuid';
3
4
 
@@ -14,7 +15,7 @@ class _FileModel extends BaseModel<'files'> {
14
15
  return this._addWithSync(file, `file-${id}`);
15
16
  }
16
17
 
17
- async findById(id: string) {
18
+ async findById(id: string): Promise<DBModel<DB_File>> {
18
19
  return this.table.get(id);
19
20
  }
20
21
 
package/src/database/client/schemas/files.ts CHANGED
@@ -8,7 +8,7 @@ export const DB_FileSchema = z.object({
8
8
  /**
9
9
  * file data array buffer
10
10
  */
11
- data: z.instanceof(ArrayBuffer),
11
+ data: z.instanceof(ArrayBuffer).optional(),
12
12
  /**
13
13
  * file type
14
14
  * @example 'image/png'
@@ -33,7 +33,7 @@ export const DB_FileSchema = z.object({
33
33
  /**
34
34
  * file url if saveMode is url
35
35
  */
36
- url: z.string().url().optional(),
36
+ url: z.string().optional(),
37
37
  });
38
38
 
39
39
  export type DB_File = z.infer<typeof DB_FileSchema>;
package/src/libs/agent-runtime/google/index.test.ts CHANGED
@@ -560,7 +560,7 @@ describe('LobeGoogleAI', () => {
560
560
  });
561
561
  });
562
562
 
563
- it('should correctly convert message with content parts', () => {
563
+ it('should correctly convert message with inline base64 image parts', () => {
564
564
  const message: OpenAIChatMessage = {
565
565
  role: 'user',
566
566
  content: [
@@ -571,6 +571,25 @@ describe('LobeGoogleAI', () => {
571
571
 
572
572
  const converted = instance['convertOAIMessagesToGoogleMessage'](message);
573
573
 
574
+ expect(converted).toEqual({
575
+ role: 'user',
576
+ parts: [
577
+ { text: 'Check this image:' },
578
+ { inlineData: { data: '...', mimeType: 'image/png' } },
579
+ ],
580
+ });
581
+ });
582
+ it.skip('should correctly convert message with image url parts', () => {
583
+ const message: OpenAIChatMessage = {
584
+ role: 'user',
585
+ content: [
586
+ { type: 'text', text: 'Check this image:' },
587
+ { type: 'image_url', image_url: { url: 'https://image-file.com' } },
588
+ ],
589
+ };
590
+
591
+ const converted = instance['convertOAIMessagesToGoogleMessage'](message);
592
+
574
593
  expect(converted).toEqual({
575
594
  role: 'user',
576
595
  parts: [
package/src/libs/agent-runtime/google/index.ts CHANGED
@@ -115,18 +115,31 @@ export class LobeGoogleAI implements LobeRuntimeAI {
115
115
  return { text: content.text };
116
116
  }
117
117
  case 'image_url': {
118
- const { mimeType, base64 } = parseDataUri(content.image_url.url);
118
+ const { mimeType, base64, type } = parseDataUri(content.image_url.url);
119
119
 
120
- if (!base64) {
121
- throw new TypeError("Image URL doesn't contain base64 data");
120
+ if (type === 'base64') {
121
+ if (!base64) {
122
+ throw new TypeError("Image URL doesn't contain base64 data");
123
+ }
124
+
125
+ return {
126
+ inlineData: {
127
+ data: base64,
128
+ mimeType: mimeType || 'image/png',
129
+ },
130
+ };
122
131
  }
123
132
 
124
- return {
125
- inlineData: {
126
- data: base64,
127
- mimeType: mimeType || 'image/png',
128
- },
129
- };
133
+ // if (type === 'url') {
134
+ // return {
135
+ // fileData: {
136
+ // fileUri: content.image_url.url,
137
+ // mimeType: mimeType || 'image/png',
138
+ // },
139
+ // };
140
+ // }
141
+
142
+ throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
130
143
  }
131
144
  }
132
145
  };
package/src/libs/agent-runtime/utils/uriParser.test.ts ADDED
@@ -0,0 +1,29 @@
1
+ import { describe, expect, it } from 'vitest';
2
+
3
+ import { parseDataUri } from './uriParser';
4
+
5
+ describe('parseDataUri', () => {
6
+ it('should parse a valid data URI', () => {
7
+ const dataUri = 'data:image/png;base64,abc';
8
+ const result = parseDataUri(dataUri);
9
+ expect(result).toEqual({ base64: 'abc', mimeType: 'image/png', type: 'base64' });
10
+ });
11
+
12
+ it('should parse a valid URL', () => {
13
+ const url = 'https://example.com/image.jpg';
14
+ const result = parseDataUri(url);
15
+ expect(result).toEqual({ base64: null, mimeType: null, type: 'url' });
16
+ });
17
+
18
+ it('should return null for an invalid input', () => {
19
+ const invalidInput = 'invalid-data';
20
+ const result = parseDataUri(invalidInput);
21
+ expect(result).toEqual({ base64: null, mimeType: null, type: null });
22
+ });
23
+
24
+ it('should handle an empty input', () => {
25
+ const emptyInput = '';
26
+ const result = parseDataUri(emptyInput);
27
+ expect(result).toEqual({ base64: null, mimeType: null, type: null });
28
+ });
29
+ });
package/src/libs/agent-runtime/utils/uriParser.ts CHANGED
@@ -1,16 +1,24 @@
1
- export const parseDataUri = (
2
- dataUri: string,
3
- ): { base64: string | null; mimeType: string | null } => {
1
+ interface UriParserResult {
2
+ base64: string | null;
3
+ mimeType: string | null;
4
+ type: 'url' | 'base64' | null;
5
+ }
6
+
7
+ export const parseDataUri = (dataUri: string): UriParserResult => {
4
8
  // 正则表达式匹配整个 Data URI 结构
5
9
  const dataUriMatch = dataUri.match(/^data:([^;]+);base64,(.+)$/);
6
10
 
7
- // 如果匹配成功,则返回 mimeType 和 base64,否则返回 null
8
11
  if (dataUriMatch) {
9
- return {
10
- base64: dataUriMatch[2],
11
- mimeType: dataUriMatch[1],
12
- };
12
+ // 如果是合法的 Data URI
13
+ return { base64: dataUriMatch[2], mimeType: dataUriMatch[1], type: 'base64' };
13
14
  }
14
15
 
15
- return { base64: null, mimeType: null };
16
+ try {
17
+ new URL(dataUri);
18
+ // 如果是合法的 URL
19
+ return { base64: null, mimeType: null, type: 'url' };
20
+ } catch {
21
+ // 既不是 Data URI 也不是合法 URL
22
+ return { base64: null, mimeType: null, type: null };
23
+ }
16
24
  };
package/src/libs/trpc/client.ts CHANGED
@@ -1,13 +1,15 @@
1
1
  import { createTRPCClient, httpBatchLink } from '@trpc/client';
2
2
  import superjson from 'superjson';
3
3
 
4
- import type { AppRouter } from '@/server/routers';
4
+ import type { EdgeRouter } from '@/server/routers';
5
+ import { createHeaderWithAuth } from '@/services/_auth';
5
6
 
6
- export const trpcClient = createTRPCClient<AppRouter>({
7
+ export const edgeClient = createTRPCClient<EdgeRouter>({
7
8
  links: [
8
9
  httpBatchLink({
10
+ headers: async () => createHeaderWithAuth(),
9
11
  transformer: superjson,
10
- url: '/trpc',
12
+ url: '/trpc/edge',
11
13
  }),
12
14
  ],
13
15
  });
package/src/libs/trpc/index.ts CHANGED
@@ -7,60 +7,36 @@
7
7
  * @link https://trpc.io/docs/v11/router
8
8
  * @link https://trpc.io/docs/v11/procedures
9
9
  */
10
- import { TRPCError, initTRPC } from '@trpc/server';
11
- import superjson from 'superjson';
12
-
13
- import type { Context } from '@/server/context';
14
-
15
- const t = initTRPC.context<Context>().create({
16
- /**
17
- * @link https://trpc.io/docs/v11/error-formatting
18
- */
19
- errorFormatter({ shape }) {
20
- return shape;
21
- },
22
- /**
23
- * @link https://trpc.io/docs/v11/data-transformers
24
- */
25
- transformer: superjson,
26
- });
10
+ import { trpc } from './init';
11
+ import { passwordChecker } from './middleware/password';
12
+ import { userAuth } from './middleware/userAuth';
27
13
 
28
14
  /**
29
15
  * Create a router
30
16
  * @link https://trpc.io/docs/v11/router
31
17
  */
32
- export const router = t.router;
18
+ export const router = trpc.router;
33
19
 
34
20
  /**
35
21
  * Create an unprotected procedure
36
22
  * @link https://trpc.io/docs/v11/procedures
37
23
  **/
38
- export const publicProcedure = t.procedure;
24
+ export const publicProcedure = trpc.procedure;
39
25
 
40
26
  // procedure that asserts that the user is logged in
41
- export const authedProcedure = t.procedure.use(async (opts) => {
42
- const { ctx } = opts;
43
- // `ctx.user` is nullable
44
- if (!ctx.userId) {
45
- throw new TRPCError({ code: 'UNAUTHORIZED' });
46
- }
27
+ export const authedProcedure = trpc.procedure.use(userAuth);
47
28
 
48
- return opts.next({
49
- ctx: {
50
- // ✅ user value is known to be non-null now
51
- userId: ctx.userId,
52
- },
53
- });
54
- });
29
+ // procedure that asserts that the user add the password
30
+ export const passwordProcedure = trpc.procedure.use(passwordChecker);
55
31
 
56
32
  /**
57
33
  * Merge multiple routers together
58
34
  * @link https://trpc.io/docs/v11/merging-routers
59
35
  */
60
- export const mergeRouters = t.mergeRouters;
36
+ export const mergeRouters = trpc.mergeRouters;
61
37
 
62
38
  /**
63
39
  * Create a server-side caller
64
40
  * @link https://trpc.io/docs/v11/server/server-side-calls
65
41
  */
66
- export const createCallerFactory = t.createCallerFactory;
42
+ export const createCallerFactory = trpc.createCallerFactory;
package/src/libs/trpc/init.ts ADDED
@@ -0,0 +1,26 @@
1
+ /**
2
+ * This is your entry point to setup the root configuration for tRPC on the server.
3
+ * - `initTRPC` should only be used once per app.
4
+ * - We export only the functionality that we use so we can enforce which base procedures should be used
5
+ *
6
+ * Learn how to create protected base procedures and other things below:
7
+ * @link https://trpc.io/docs/v11/router
8
+ * @link https://trpc.io/docs/v11/procedures
9
+ */
10
+ import { initTRPC } from '@trpc/server';
11
+ import superjson from 'superjson';
12
+
13
+ import type { Context } from '@/server/context';
14
+
15
+ export const trpc = initTRPC.context<Context>().create({
16
+ /**
17
+ * @link https://trpc.io/docs/v11/error-formatting
18
+ */
19
+ errorFormatter({ shape }) {
20
+ return shape;
21
+ },
22
+ /**
23
+ * @link https://trpc.io/docs/v11/data-transformers
24
+ */
25
+ transformer: superjson,
26
+ });
package/src/libs/trpc/middleware/password.test.ts ADDED
@@ -0,0 +1,87 @@
1
+ import { TRPCError } from '@trpc/server';
2
+ import { beforeEach, describe, expect, it, vi } from 'vitest';
3
+
4
+ import * as utils from '@/app/api/middleware/auth/utils';
5
+ import * as serverConfig from '@/config/server';
6
+ import { createCallerFactory } from '@/libs/trpc';
7
+ import { trpc } from '@/libs/trpc/init';
8
+ import { AuthContext, createContextInner } from '@/server/context';
9
+
10
+ import { passwordChecker } from './password';
11
+
12
+ const appRouter = trpc.router({
13
+ protectedQuery: trpc.procedure.use(passwordChecker).query(async ({ ctx }) => {
14
+ return ctx.jwtPayload;
15
+ }),
16
+ });
17
+
18
+ const createCaller = createCallerFactory(appRouter);
19
+ let ctx: AuthContext;
20
+ let router: ReturnType<typeof createCaller>;
21
+
22
+ beforeEach(() => {
23
+ vi.resetAllMocks();
24
+ });
25
+
26
+ describe('passwordChecker middleware', () => {
27
+ it('should throw UNAUTHORIZED error if authorizationHeader is not present in context', async () => {
28
+ ctx = await createContextInner();
29
+ router = createCaller(ctx);
30
+
31
+ await expect(router.protectedQuery()).rejects.toThrow(new TRPCError({ code: 'UNAUTHORIZED' }));
32
+ });
33
+
34
+ it('should throw UNAUTHORIZED error if access code is not correct', async () => {
35
+ vi.spyOn(serverConfig, 'getServerConfig').mockReturnValue({
36
+ ACCESS_CODES: ['123'],
37
+ } as any);
38
+ vi.spyOn(utils, 'getJWTPayload').mockResolvedValue({ accessCode: '456' });
39
+
40
+ ctx = await createContextInner({ authorizationHeader: 'Bearer token' });
41
+ router = createCaller(ctx);
42
+
43
+ await expect(router.protectedQuery()).rejects.toThrow(new TRPCError({ code: 'UNAUTHORIZED' }));
44
+ });
45
+
46
+ it('should call next with jwtPayload in context if access code is correct', async () => {
47
+ const jwtPayload = { accessCode: '123' };
48
+ vi.spyOn(serverConfig, 'getServerConfig').mockReturnValue({
49
+ ACCESS_CODES: ['123'],
50
+ } as any);
51
+ vi.spyOn(utils, 'getJWTPayload').mockResolvedValue(jwtPayload);
52
+
53
+ ctx = await createContextInner({ authorizationHeader: 'Bearer token' });
54
+ router = createCaller(ctx);
55
+
56
+ const data = await router.protectedQuery();
57
+
58
+ expect(data).toEqual(jwtPayload);
59
+ });
60
+
61
+ it('should call next with jwtPayload in context if no access codes are set', async () => {
62
+ const jwtPayload = {};
63
+ vi.spyOn(serverConfig, 'getServerConfig').mockReturnValue({
64
+ ACCESS_CODES: [],
65
+ } as any);
66
+ vi.spyOn(utils, 'getJWTPayload').mockResolvedValue(jwtPayload);
67
+
68
+ ctx = await createContextInner({ authorizationHeader: 'Bearer token' });
69
+ router = createCaller(ctx);
70
+
71
+ const data = await router.protectedQuery();
72
+
73
+ expect(data).toEqual(jwtPayload);
74
+ });
75
+ it('should call next with jwtPayload in context if access codes is undefined', async () => {
76
+ const jwtPayload = {};
77
+ vi.spyOn(serverConfig, 'getServerConfig').mockReturnValue({} as any);
78
+ vi.spyOn(utils, 'getJWTPayload').mockResolvedValue(jwtPayload);
79
+
80
+ ctx = await createContextInner({ authorizationHeader: 'Bearer token' });
81
+ router = createCaller(ctx);
82
+
83
+ const data = await router.protectedQuery();
84
+
85
+ expect(data).toEqual(jwtPayload);
86
+ });
87
+ });
package/src/libs/trpc/middleware/password.ts ADDED
@@ -0,0 +1,26 @@
1
+ import { TRPCError } from '@trpc/server';
2
+
3
+ import { getJWTPayload } from '@/app/api/middleware/auth/utils';
4
+ import { getServerConfig } from '@/config/server';
5
+ import { trpc } from '@/libs/trpc/init';
6
+
7
+ export const passwordChecker = trpc.middleware(async (opts) => {
8
+ const { ACCESS_CODES } = getServerConfig();
9
+
10
+ const { ctx } = opts;
11
+
12
+ if (!ctx.authorizationHeader) throw new TRPCError({ code: 'UNAUTHORIZED' });
13
+
14
+ const jwtPayload = await getJWTPayload(ctx.authorizationHeader);
15
+
16
+ // if there are access codes, check if the user has set correct one
17
+ if (ACCESS_CODES && ACCESS_CODES.length > 0) {
18
+ const accessCode = jwtPayload.accessCode;
19
+
20
+ if (!accessCode || !ACCESS_CODES.includes(accessCode)) {
21
+ throw new TRPCError({ code: 'UNAUTHORIZED' });
22
+ }
23
+ }
24
+
25
+ return opts.next({ ctx: { jwtPayload } });
26
+ });
package/src/libs/trpc/middleware/userAuth.test.ts ADDED
@@ -0,0 +1,44 @@
1
+ import { TRPCError } from '@trpc/server';
2
+ import { beforeEach, describe, expect, it, vi } from 'vitest';
3
+
4
+ import { createCallerFactory } from '@/libs/trpc';
5
+ import { AuthContext, createContextInner } from '@/server/context';
6
+
7
+ import { trpc } from '../init';
8
+ import { userAuth } from './userAuth';
9
+
10
+ const appRouter = trpc.router({
11
+ protectedQuery: trpc.procedure.use(userAuth).query(async ({ ctx }) => {
12
+ return ctx.userId;
13
+ }),
14
+ });
15
+
16
+ const createCaller = createCallerFactory(appRouter);
17
+ let ctx: AuthContext;
18
+ let router: ReturnType<typeof createCaller>;
19
+
20
+ beforeEach(async () => {
21
+ vi.resetAllMocks();
22
+ });
23
+
24
+ describe('userAuth middleware', () => {
25
+ it('should throw UNAUTHORIZED error if userId is not present in context', async () => {
26
+ ctx = await createContextInner();
27
+ router = createCaller(ctx);
28
+
29
+ try {
30
+ await router.protectedQuery();
31
+ } catch (e) {
32
+ expect(e).toEqual(new TRPCError({ code: 'UNAUTHORIZED' }));
33
+ }
34
+ });
35
+
36
+ it('should call next with userId in context if userId is present', async () => {
37
+ ctx = await createContextInner({ userId: 'user-id' });
38
+ router = createCaller(ctx);
39
+
40
+ const data = await router.protectedQuery();
41
+
42
+ expect(data).toEqual('user-id');
43
+ });
44
+ });