@lobehub/lobehub 2.0.0-next.127 → 2.0.0-next.129

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/.env.example +23 -3
  2. package/.env.example.development +5 -0
  3. package/CHANGELOG.md +50 -0
  4. package/README.md +6 -6
  5. package/README.zh-CN.md +6 -6
  6. package/changelog/v1.json +18 -0
  7. package/docker-compose/local/docker-compose.yml +24 -1
  8. package/docker-compose/local/logto/docker-compose.yml +25 -2
  9. package/docker-compose.development.yml +6 -0
  10. package/docs/development/database-schema.dbml +8 -6
  11. package/locales/ar/auth.json +114 -1
  12. package/locales/bg-BG/auth.json +114 -1
  13. package/locales/de-DE/auth.json +114 -1
  14. package/locales/en-US/auth.json +42 -22
  15. package/locales/es-ES/auth.json +114 -1
  16. package/locales/fa-IR/auth.json +114 -1
  17. package/locales/fr-FR/auth.json +114 -1
  18. package/locales/it-IT/auth.json +114 -1
  19. package/locales/ja-JP/auth.json +114 -1
  20. package/locales/ko-KR/auth.json +114 -1
  21. package/locales/nl-NL/auth.json +114 -1
  22. package/locales/pl-PL/auth.json +114 -1
  23. package/locales/pt-BR/auth.json +114 -1
  24. package/locales/ru-RU/auth.json +114 -1
  25. package/locales/tr-TR/auth.json +114 -1
  26. package/locales/vi-VN/auth.json +114 -1
  27. package/locales/zh-CN/auth.json +36 -29
  28. package/locales/zh-TW/auth.json +114 -1
  29. package/package.json +4 -1
  30. package/packages/database/migrations/0050_thread_and_user_id.sql +18 -0
  31. package/packages/database/migrations/meta/0050_snapshot.json +8792 -0
  32. package/packages/database/migrations/meta/_journal.json +7 -0
  33. package/packages/database/src/client/db.ts +21 -21
  34. package/packages/database/src/core/migrations.json +51 -10
  35. package/packages/database/src/repositories/dataImporter/deprecated/index.ts +5 -5
  36. package/packages/database/src/repositories/dataImporter/index.ts +59 -59
  37. package/packages/database/src/repositories/knowledge/index.test.ts +17 -5
  38. package/packages/database/src/repositories/knowledge/index.ts +6 -6
  39. package/packages/database/src/schemas/generation.ts +16 -16
  40. package/packages/database/src/schemas/nextauth.ts +3 -3
  41. package/packages/database/src/schemas/oidc.ts +36 -36
  42. package/packages/database/src/schemas/topic.ts +8 -3
  43. package/packages/model-runtime/src/providers/newapi/index.ts +61 -18
  44. package/packages/model-runtime/src/runtimeMap.ts +1 -0
  45. package/packages/types/src/topic/thread.ts +3 -3
  46. package/src/app/[variants]/(main)/settings/provider/features/ProviderConfig/UpdateProviderInfo/SettingModal.tsx +10 -6
  47. package/src/envs/redis.ts +106 -0
  48. package/src/libs/redis/index.ts +5 -0
  49. package/src/libs/redis/manager.test.ts +107 -0
  50. package/src/libs/redis/manager.ts +56 -0
  51. package/src/libs/redis/redis.test.ts +158 -0
  52. package/src/libs/redis/redis.ts +117 -0
  53. package/src/libs/redis/types.ts +71 -0
  54. package/src/libs/redis/upstash.test.ts +154 -0
  55. package/src/libs/redis/upstash.ts +109 -0
  56. package/src/libs/redis/utils.test.ts +46 -0
  57. package/src/libs/redis/utils.ts +53 -0
  58. package/src/store/chat/slices/thread/action.ts +1 -1
  59. package/src/store/chat/slices/thread/initialState.ts +1 -1
  60. package/src/store/chat/slices/thread/selectors/util.ts +1 -1
  61. package/.github/workflows/check-console-log.yml +0 -117
@@ -6,8 +6,8 @@ import { timestamps, timestamptz } from './_helpers';
6
6
  import { users } from './user';
7
7
 
8
8
  /**
9
- * OIDC 授权码
10
- * oidc-provider 需要持久化的模型之一
9
+ * OIDC authorization code
10
+ * One of the models that oidc-provider needs to persist
11
11
  */
12
12
  export const oidcAuthorizationCodes = pgTable('oidc_authorization_codes', {
13
13
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -23,8 +23,8 @@ export const oidcAuthorizationCodes = pgTable('oidc_authorization_codes', {
23
23
  });
24
24
 
25
25
  /**
26
- * OIDC 访问令牌
27
- * oidc-provider 需要持久化的模型之一
26
+ * OIDC access token
27
+ * One of the models that oidc-provider needs to persist
28
28
  */
29
29
  export const oidcAccessTokens = pgTable('oidc_access_tokens', {
30
30
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -40,8 +40,8 @@ export const oidcAccessTokens = pgTable('oidc_access_tokens', {
40
40
  });
41
41
 
42
42
  /**
43
- * OIDC 刷新令牌
44
- * oidc-provider 需要持久化的模型之一
43
+ * OIDC refresh token
44
+ * One of the models that oidc-provider needs to persist
45
45
  */
46
46
  export const oidcRefreshTokens = pgTable('oidc_refresh_tokens', {
47
47
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -57,8 +57,8 @@ export const oidcRefreshTokens = pgTable('oidc_refresh_tokens', {
57
57
  });
58
58
 
59
59
  /**
60
- * OIDC 设备代码
61
- * oidc-provider 需要持久化的模型之一
60
+ * OIDC device code
61
+ * One of the models that oidc-provider needs to persist
62
62
  */
63
63
  export const oidcDeviceCodes = pgTable('oidc_device_codes', {
64
64
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -73,8 +73,8 @@ export const oidcDeviceCodes = pgTable('oidc_device_codes', {
73
73
  });
74
74
 
75
75
  /**
76
- * OIDC 交互会话
77
- * oidc-provider 需要持久化的模型之一
76
+ * OIDC interaction session
77
+ * One of the models that oidc-provider needs to persist
78
78
  */
79
79
  export const oidcInteractions = pgTable('oidc_interactions', {
80
80
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -84,8 +84,8 @@ export const oidcInteractions = pgTable('oidc_interactions', {
84
84
  });
85
85
 
86
86
  /**
87
- * OIDC 授权记录
88
- * oidc-provider 需要持久化的模型之一
87
+ * OIDC grant record
88
+ * One of the models that oidc-provider needs to persist
89
89
  */
90
90
  export const oidcGrants = pgTable('oidc_grants', {
91
91
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -100,14 +100,14 @@ export const oidcGrants = pgTable('oidc_grants', {
100
100
  });
101
101
 
102
102
  /**
103
- * OIDC 客户端配置
104
- * 存储 OIDC 客户端配置信息
103
+ * OIDC client configuration
104
+ * Stores OIDC client configuration information
105
105
  */
106
106
  export const oidcClients = pgTable('oidc_clients', {
107
107
  id: varchar('id', { length: 255 }).primaryKey(), // client_id
108
108
  name: text('name').notNull(),
109
109
  description: text('description'),
110
- clientSecret: varchar('client_secret', { length: 255 }), // 公共客户端可为 null
110
+ clientSecret: varchar('client_secret', { length: 255 }), // Can be null for public clients
111
111
  redirectUris: text('redirect_uris').array().notNull(),
112
112
  grants: text('grants').array().notNull(),
113
113
  responseTypes: text('response_types').array().notNull(),
@@ -123,8 +123,8 @@ export const oidcClients = pgTable('oidc_clients', {
123
123
  });
124
124
 
125
125
  /**
126
- * OIDC 会话
127
- * oidc-provider 需要持久化的模型之一
126
+ * OIDC session
127
+ * One of the models that oidc-provider needs to persist
128
128
  */
129
129
  export const oidcSessions = pgTable('oidc_sessions', {
130
130
  id: varchar('id', { length: 255 }).primaryKey(),
@@ -137,8 +137,8 @@ export const oidcSessions = pgTable('oidc_sessions', {
137
137
  });
138
138
 
139
139
  /**
140
- * OIDC 授权同意记录
141
- * 记录用户对客户端的授权同意历史
140
+ * OIDC authorization consent record
141
+ * Records user authorization consent history for clients
142
142
  */
143
143
  export const oidcConsents = pgTable(
144
144
  'oidc_consents',
@@ -159,39 +159,39 @@ export const oidcConsents = pgTable(
159
159
  );
160
160
 
161
161
  /**
162
- * 通用认证凭证传递表
163
- * 用于在不同客户端(桌面端、浏览器插件、移动端等)之间安全传递认证凭证
162
+ * Generic authentication credential handoff table
163
+ * Used to securely pass authentication credentials between different clients (desktop, browser extension, mobile, etc.)
164
164
  *
165
- * 工作流程:
166
- * 1. 客户端生成唯一的 handoff ID
167
- * 2. handoff ID 作为参数附加到 OAuth redirect_uri
168
- * 3. 认证成功后,中间页将凭证存储到此表
169
- * 4. 客户端轮询此表获取凭证
170
- * 5. 成功获取后立即删除记录
165
+ * Workflow:
166
+ * 1. Client generates a unique handoff ID
167
+ * 2. Appends handoff ID as a parameter to OAuth redirect_uri
168
+ * 3. After successful authentication, intermediate page stores credentials in this table
169
+ * 4. Client polls this table to retrieve credentials
170
+ * 5. Record is immediately deleted after successful retrieval
171
171
  */
172
172
  export const oauthHandoffs = pgTable('oauth_handoffs', {
173
173
  /**
174
- * 由客户端生成的一次性唯一标识符
175
- * 用于客户端轮询时认领自己的凭证
174
+ * One-time unique identifier generated by the client
175
+ * Used for client polling to claim its own credentials
176
176
  */
177
177
  id: text('id').primaryKey(),
178
178
 
179
179
  /**
180
- * 客户端类型标识
181
- * 如: 'desktop', 'browser-extension', 'mobile-app'
180
+ * Client type identifier
181
+ * Examples: 'desktop', 'browser-extension', 'mobile-app', etc.
182
182
  */
183
183
  client: varchar('client', { length: 50 }).notNull(),
184
184
 
185
185
  /**
186
- * 凭证数据的 JSON 载荷
187
- * 灵活存储不同认证流程所需的各种数据
188
- * 当前主要包含: { code: string; state: string }
186
+ * JSON payload for credential data
187
+ * Flexible storage for various data required by different authentication flows
188
+ * Currently mainly contains: { code: string; state: string }
189
189
  */
190
190
  payload: jsonb('payload').$type<Record<string, unknown>>().notNull(),
191
191
 
192
192
  /**
193
- * 时间戳字段,用于 TTL 控制
194
- * 凭证应在创建后 5 分钟内被消费,否则视为过期
193
+ * Timestamp fields for TTL control
194
+ * Credentials should be consumed within 5 minutes of creation, otherwise considered expired
195
195
  */
196
196
  ...timestamps,
197
197
  });
@@ -49,12 +49,17 @@ export const threads = pgTable(
49
49
  .primaryKey(),
50
50
 
51
51
  title: text('title'),
52
- type: text('type', { enum: ['continuation', 'standalone'] }).notNull(),
53
- status: text('status', { enum: ['active', 'deprecated', 'archived'] }).default('active'),
52
+ content: text('content'),
53
+ editor_data: jsonb('editor_data'),
54
+ type: text('type', { enum: ['continuation', 'standalone', 'isolation'] }).notNull(),
55
+ status: text('status', {
56
+ enum: ['active', 'processing', 'pending', 'inReview', 'todo', 'cancel'],
57
+ }),
58
+
54
59
  topicId: text('topic_id')
55
60
  .references(() => topics.id, { onDelete: 'cascade' })
56
61
  .notNull(),
57
- sourceMessageId: text('source_message_id').notNull(),
62
+ sourceMessageId: text('source_message_id'),
58
63
  // @ts-ignore
59
64
  parentThreadId: text('parent_thread_id').references(() => threads.id, { onDelete: 'set null' }),
60
65
  clientId: text('client_id'),
@@ -25,6 +25,62 @@ export interface NewAPIPricing {
25
25
  supported_endpoint_types?: string[];
26
26
  }
27
27
 
28
/**
 * Report whether the current runtime looks like a browser, i.e. both the
 * `window` and `document` globals are defined.
 */
const isBrowser = (): boolean =>
  !(typeof window === 'undefined' || typeof document === 'undefined');
32
+
33
/**
 * Parse a NewAPI pricing HTTP response into its list of pricing entries.
 * Shared between the browser (proxied) and server (direct) fetch paths.
 *
 * @param res - Response from the pricing endpoint, or from `/webapi/proxy` relaying it.
 * @returns the `data` array when the body looks like `{ success: <truthy>, data: [...] }`,
 *   with non-object entries dropped; `null` for a non-2xx status, a body that is not JSON,
 *   a falsy `success`, or a `data` field that is not an array.
 */
const parsePricingResponse = async (res: Response): Promise<NewAPIPricing[] | null> => {
  if (!res.ok) {
    return null;
  }

  let body: unknown;
  try {
    body = await res.json();
  } catch {
    // Not JSON (e.g. an HTML error page served with 200): treat as "no pricing available".
    return null;
  }

  if (typeof body !== 'object' || body === null) {
    return null;
  }

  const { data, success } = body as { data?: unknown; success?: unknown };

  // The caller iterates this result with `forEach` and reads `model_name` on every entry
  // without a surrounding try/catch, so anything that is not an array of objects must be
  // filtered out here instead of being passed through with an unchecked cast.
  if (!success || !Array.isArray(data)) {
    return null;
  }

  return data.filter(
    (entry): entry is NewAPIPricing => typeof entry === 'object' && entry !== null,
  );
};
49
+
50
/**
 * Fetch NewAPI pricing entries. In the browser the request is routed through
 * `/webapi/proxy` to avoid CORS errors; on the server the endpoint is called directly.
 *
 * @param pricingUrl - Absolute URL of the pricing endpoint (e.g. `${baseURL}/api/pricing`).
 * @param apiKey - Sent as a Bearer token on the direct (server-side) request only. An empty
 *   key sends no `Authorization` header rather than a malformed `Bearer ` one.
 * @returns the parsed pricing list, or `null` on any network, HTTP, or parse failure.
 *   Pricing only enriches model details, so failures are logged at debug level and swallowed.
 */
const fetchPricing = async (
  pricingUrl: string,
  apiKey: string,
): Promise<NewAPIPricing[] | null> => {
  try {
    let response: Response;

    if (isBrowser()) {
      // The proxy endpoint expects the target URL as the raw request body.
      // NOTE(review): no Authorization header reaches the upstream on this path;
      // confirm the pricing endpoint is readable without an API key.
      response = await fetch('/webapi/proxy', {
        body: pricingUrl,
        method: 'POST',
      });
    } else {
      response = await fetch(pricingUrl, {
        headers: apiKey ? { Authorization: `Bearer ${apiKey}` } : undefined,
      });
    }

    return await parsePricingResponse(response);
  } catch (error) {
    console.debug('Failed to fetch NewAPI pricing info:', error);
    return null;
  }
};
83
+
28
84
  export const params = {
29
85
  debug: {
30
86
  chatCompletion: () => process.env.DEBUG_NEWAPI_CHAT_COMPLETION === '1',
@@ -42,25 +98,12 @@ export const params = {
42
98
 
43
99
  // Try to get pricing information to enrich model details
44
100
  let pricingMap: Map<string, NewAPIPricing> = new Map();
45
- try {
46
- // Use saved baseURL
47
- const pricingResponse = await fetch(`${baseURL}/api/pricing`, {
48
- headers: {
49
- Authorization: `Bearer ${openAIClient.apiKey}`,
50
- },
51
- });
52
101
 
53
- if (pricingResponse.ok) {
54
- const pricingData = await pricingResponse.json();
55
- if (pricingData.success && pricingData.data) {
56
- (pricingData.data as NewAPIPricing[]).forEach((pricing) => {
57
- pricingMap.set(pricing.model_name, pricing);
58
- });
59
- }
60
- }
61
- } catch (error) {
62
- // If fetching pricing information fails, continue using the basic model information
63
- console.debug('Failed to fetch NewAPI pricing info:', error);
102
+ const pricingList = await fetchPricing(`${baseURL}/api/pricing`, openAIClient.apiKey || '');
103
+ if (pricingList) {
104
+ pricingList.forEach((pricing) => {
105
+ pricingMap.set(pricing.model_name, pricing);
106
+ });
64
107
  }
65
108
 
66
109
  // Process the model list: determine the provider for each model based on priority rules
@@ -112,6 +112,7 @@ export const providerRuntimeMap = {
112
112
  ppio: LobePPIOAI,
113
113
  qiniu: LobeQiniuAI,
114
114
  qwen: LobeQwenAI,
115
+ router: LobeNewAPIAI,
115
116
  sambanova: LobeSambaNovaAI,
116
117
  search1api: LobeSearch1API,
117
118
  sensenova: LobeSenseNovaAI,
@@ -16,7 +16,7 @@ export interface ThreadItem {
16
16
  id: string;
17
17
  lastActiveAt: Date;
18
18
  parentThreadId?: string;
19
- sourceMessageId: string;
19
+ sourceMessageId?: string | null;
20
20
  status: ThreadStatus;
21
21
  title: string;
22
22
  topicId: string;
@@ -27,7 +27,7 @@ export interface ThreadItem {
27
27
 
28
28
  export interface CreateThreadParams {
29
29
  parentThreadId?: string;
30
- sourceMessageId: string;
30
+ sourceMessageId?: string;
31
31
  title?: string;
32
32
  topicId: string;
33
33
  type: ThreadType;
@@ -35,7 +35,7 @@ export interface CreateThreadParams {
35
35
 
36
36
  export const createThreadSchema = z.object({
37
37
  parentThreadId: z.string().optional(),
38
- sourceMessageId: z.string(),
38
+ sourceMessageId: z.string().optional(),
39
39
  title: z.string().optional(),
40
40
  topicId: z.string(),
41
41
  type: z.nativeEnum(ThreadType),
@@ -84,12 +84,16 @@ const CreateNewProvider = memo<CreateNewProviderProps>(({ onClose, open, initial
84
84
  {
85
85
  children: (
86
86
  <Select
87
- optionRender={({ label, value }) => (
88
- <Flexbox align={'center'} gap={8} horizontal>
89
- <ProviderIcon provider={value as string} size={18} />
90
- {label}
91
- </Flexbox>
92
- )}
87
+ optionRender={({ label, value }) => {
88
+ // Map 'router' to 'newapi' for displaying the correct icon
89
+ const iconProvider = value === 'router' ? 'newapi' : (value as string);
90
+ return (
91
+ <Flexbox align={'center'} gap={8} horizontal>
92
+ <ProviderIcon provider={iconProvider} size={18} />
93
+ {label}
94
+ </Flexbox>
95
+ );
96
+ }}
93
97
  options={CUSTOM_PROVIDER_SDK_OPTIONS}
94
98
  placeholder={t('createNewAiProvider.sdkType.placeholder')}
95
99
  variant={'filled'}
@@ -0,0 +1,106 @@
1
+ /* eslint-disable sort-keys-fix/sort-keys-fix */
2
+ import { createEnv } from '@t3-oss/env-nextjs';
3
+ import { z } from 'zod';
4
+
5
+ import type { RedisConfig } from '@/libs/redis';
6
+
7
/** Upstash Redis REST credentials: the REST endpoint URL and its access token. */
interface UpstashRedisConfig {
  token: string;
  url: string;
}
8
+
9
/**
 * Parse an optional environment string as a base-10 integer.
 *
 * Uses `parseInt` semantics, so a leading integer prefix is accepted ('3abc' -> 3,
 * '1.9' -> 1). Returns `undefined` when the value is missing, has no leading digits,
 * or overflows to `Infinity`.
 */
const parseNumber = (value?: string): number | undefined => {
  if (value === undefined) return undefined;

  const result = Number.parseInt(value, 10);
  // NaN (no digits) and Infinity (huge digit runs) are both rejected here.
  if (!Number.isInteger(result)) return undefined;

  return result;
};
14
+
15
/**
 * Interpret the `REDIS_TLS` flag. Only 'true' or '1' (case-insensitive, surrounding
 * whitespace ignored) enable TLS; a missing or empty value and anything else disable it.
 */
const parseRedisTls = (value?: string): boolean => {
  const flag = value ? value.trim().toLowerCase() : '';

  return ['true', '1'].includes(flag);
};
23
+
24
/**
 * Read the Redis / Upstash settings from `process.env` and validate them with `createEnv`.
 *
 * Raw values are normalized before validation:
 * - `REDIS_DATABASE` is parsed to an integer; unparsable input becomes `undefined`.
 * - `REDIS_TLS` is parsed to a boolean.
 * - An unset or empty `REDIS_PREFIX` falls back to `'lobechat'`.
 * When set, `REDIS_URL` and `UPSTASH_REDIS_REST_URL` must be valid URLs.
 */
export const getRedisEnv = () => {
  const runtimeEnv = {
    REDIS_DATABASE: parseNumber(process.env.REDIS_DATABASE),
    REDIS_PASSWORD: process.env.REDIS_PASSWORD,
    REDIS_PREFIX: process.env.REDIS_PREFIX || 'lobechat',
    REDIS_TLS: parseRedisTls(process.env.REDIS_TLS),
    REDIS_URL: process.env.REDIS_URL,
    REDIS_USERNAME: process.env.REDIS_USERNAME,
    UPSTASH_REDIS_REST_TOKEN: process.env.UPSTASH_REDIS_REST_TOKEN,
    UPSTASH_REDIS_REST_URL: process.env.UPSTASH_REDIS_REST_URL,
  };

  const server = {
    REDIS_DATABASE: z.number().int().optional(),
    REDIS_PASSWORD: z.string().optional(),
    REDIS_PREFIX: z.string(),
    REDIS_TLS: z.boolean().default(false),
    REDIS_URL: z.string().url().optional(),
    REDIS_USERNAME: z.string().optional(),
    UPSTASH_REDIS_REST_TOKEN: z.string().optional(),
    UPSTASH_REDIS_REST_URL: z.string().url().optional(),
  };

  return createEnv({ runtimeEnv, server });
};

/** Redis env evaluated once at module load, so validation runs as soon as this module is imported. */
export const redisEnv = getRedisEnv();
50
+
51
/**
 * Resolve the Upstash REST credentials from the validated Redis env.
 *
 * The token and URL must be configured together: when both are set (and the URL is valid)
 * they are returned; when neither is set the result is `null`. Setting only one of them
 * fails validation and the resulting Zod error is thrown.
 */
export const getUpstashRedisConfig = (): UpstashRedisConfig | null => {
  const bothConfigured = z.object({
    token: z.string(),
    url: z.string().url(),
  });
  const noneConfigured = z.object({
    token: z.undefined().optional(),
    url: z.undefined().optional(),
  });

  const result = z.union([bothConfigured, noneConfigured]).safeParse({
    token: redisEnv.UPSTASH_REDIS_REST_TOKEN,
    url: redisEnv.UPSTASH_REDIS_REST_URL,
  });

  if (!result.success) throw result.error;

  const { token, url } = result.data;

  // An empty token still counts as "not configured".
  return token && url ? { token, url } : null;
};
73
+
74
/**
 * Build the Redis connection config from the environment.
 *
 * Precedence: a `REDIS_URL` selects the ioredis provider; otherwise complete Upstash
 * credentials select the Upstash provider; otherwise Redis is disabled. The key prefix
 * from `REDIS_PREFIX` is carried through in every case.
 */
export const getRedisConfig = (): RedisConfig => {
  const { REDIS_PREFIX: prefix, REDIS_URL: redisUrl } = redisEnv;

  if (redisUrl) {
    return {
      database: redisEnv.REDIS_DATABASE,
      enabled: true,
      password: redisEnv.REDIS_PASSWORD,
      prefix,
      provider: 'redis',
      tls: redisEnv.REDIS_TLS,
      url: redisUrl,
      username: redisEnv.REDIS_USERNAME,
    };
  }

  const upstash = getUpstashRedisConfig();

  if (!upstash) {
    return { enabled: false, prefix, provider: false };
  }

  return {
    enabled: true,
    prefix,
    provider: 'upstash',
    token: upstash.token,
    url: upstash.url,
  };
};
@@ -0,0 +1,5 @@
1
+ export * from './manager';
2
+ export * from './redis';
3
+ export * from './types';
4
+ export * from './upstash';
5
+ export * from './utils';
@@ -0,0 +1,107 @@
1
+ import { afterEach, describe, expect, it, vi } from 'vitest';
2
+
3
+ import { RedisManager, initializeRedis, resetRedisClient } from './manager';
4
+ import { DisabledRedisConfig } from './types';
5
+
6
// Spies shared with the vi.mock factories below. They are created via vi.hoisted because
// vitest moves vi.mock calls above the imports.
const mocks = vi.hoisted(() => ({
  ioRedisDisconnect: vi.fn().mockResolvedValue(undefined),
  ioRedisInitialize: vi.fn().mockResolvedValue(undefined),
  upstashDisconnect: vi.fn().mockResolvedValue(undefined),
  upstashInitialize: vi.fn().mockResolvedValue(undefined),
}));

// Stub both providers: each constructed instance exposes its provider tag, the config it
// received, and the shared initialize/disconnect spies.
vi.mock('./redis', () => ({
  IoRedisRedisProvider: vi.fn().mockImplementation((config) => ({
    config,
    disconnect: mocks.ioRedisDisconnect,
    initialize: mocks.ioRedisInitialize,
    provider: 'redis' as const,
  })),
}));

vi.mock('./upstash', () => ({
  UpstashRedisProvider: vi.fn().mockImplementation((config) => ({
    config,
    disconnect: mocks.upstashDisconnect,
    initialize: mocks.upstashInitialize,
    provider: 'upstash' as const,
  })),
}));

const disabledConfig = {
  enabled: false,
  prefix: 'test',
  provider: false,
} satisfies DisabledRedisConfig;

const fullIoRedisConfig = {
  database: 0,
  enabled: true,
  password: 'pwd',
  prefix: 'test',
  provider: 'redis' as const,
  tls: false,
  url: 'redis://localhost:6379',
  username: 'user',
};

const minimalIoRedisConfig = {
  enabled: true,
  prefix: 'test',
  provider: 'redis' as const,
  tls: false,
  url: 'redis://localhost:6379',
};

const upstashConfig = {
  enabled: true,
  prefix: 'test',
  provider: 'upstash' as const,
  token: 'token',
  url: 'https://example.upstash.io',
};

// Every test starts with zeroed spy call counts and a fresh singleton.
afterEach(async () => {
  vi.clearAllMocks();
  await RedisManager.reset();
});

describe('RedisManager', () => {
  it('returns null when redis is disabled', async () => {
    const result = await initializeRedis(disabledConfig);

    expect(result).toBeNull();
    expect(mocks.ioRedisInitialize).not.toHaveBeenCalled();
    expect(mocks.upstashInitialize).not.toHaveBeenCalled();
  });

  it('initializes ioredis provider once and memoizes the instance', async () => {
    const results = await Promise.all([
      initializeRedis(fullIoRedisConfig),
      initializeRedis(fullIoRedisConfig),
    ]);

    expect(results[0]).toBe(results[1]);
    expect(mocks.ioRedisInitialize).toHaveBeenCalledTimes(1);
    expect(mocks.upstashInitialize).not.toHaveBeenCalled();
  });

  it('initializes upstash provider when configured', async () => {
    const result = await initializeRedis(upstashConfig);

    expect(result?.provider).toBe('upstash');
    expect(mocks.upstashInitialize).toHaveBeenCalledTimes(1);
    expect(mocks.ioRedisInitialize).not.toHaveBeenCalled();
  });

  it('disconnects existing provider on reset', async () => {
    await initializeRedis(minimalIoRedisConfig);
    await resetRedisClient();

    expect(mocks.ioRedisDisconnect).toHaveBeenCalledTimes(1);
  });
});
@@ -0,0 +1,56 @@
1
+ import { IoRedisRedisProvider } from './redis';
2
+ import { BaseRedisProvider, RedisConfig } from './types';
3
+ import { UpstashRedisProvider } from './upstash';
4
+
5
/**
 * Process-wide owner of the shared Redis provider (ioredis or Upstash).
 *
 * The first `initialize` call creates and connects one provider for the config it
 * receives. Later calls return that same provider, ignoring their config, until `reset`
 * runs. A disabled config resolves to `null` and is memoized the same way.
 */
class RedisManager {
  private static instance: BaseRedisProvider | null = null;
  // NOTICE: initPromise keeps concurrent initialize() calls sharing the same in-flight setup,
  // preventing multiple connections from being created in parallel.
  private static initPromise: Promise<BaseRedisProvider | null> | null = null;

  /**
   * Return the shared provider, creating and connecting it on first use.
   *
   * @param config - Only used by the call that actually starts initialization.
   * @returns the connected provider, or `null` when `config.enabled` is false.
   * @throws whatever provider construction or `initialize()` throws (including an
   *   "Unsupported redis provider" error). The failed attempt is forgotten so a later
   *   call can retry.
   */
  static async initialize(config: RedisConfig): Promise<BaseRedisProvider | null> {
    if (RedisManager.instance) return RedisManager.instance;
    if (RedisManager.initPromise) return RedisManager.initPromise;

    const pending: Promise<BaseRedisProvider | null> = RedisManager.createProvider(config).then(
      (provider) => {
        // Publish only if no reset() happened while connecting. Otherwise reset() has
        // already detached this attempt and disconnects the provider it produced.
        if (RedisManager.initPromise === pending) RedisManager.instance = provider;
        return provider;
      },
      (error: unknown) => {
        // Forget the failed attempt so the next initialize() call retries.
        if (RedisManager.initPromise === pending) RedisManager.initPromise = null;
        throw error;
      },
    );

    RedisManager.initPromise = pending;
    return pending;
  }

  /**
   * Disconnect and forget the shared provider so the next `initialize` starts from scratch.
   *
   * If initialization is still in flight, this waits for it and disconnects the provider
   * it produces, so a connection opened concurrently with a reset is not leaked or
   * silently re-published as the singleton.
   */
  static async reset() {
    const current = RedisManager.instance;
    const pending = RedisManager.initPromise;

    // Clear state first so initialize() calls made during teardown start a fresh provider
    // instead of receiving the one being disconnected.
    RedisManager.instance = null;
    RedisManager.initPromise = null;

    if (current) {
      await current.disconnect();
      return;
    }

    if (pending) {
      // A failed initialization has nothing to disconnect.
      const orphan = await pending.catch(() => null);
      await orphan?.disconnect();
    }
  }

  /** Construct the provider selected by `config` and wait for it to connect; `null` when disabled. */
  private static async createProvider(config: RedisConfig): Promise<BaseRedisProvider | null> {
    if (!config.enabled) return null;

    let provider: BaseRedisProvider;

    if (config.provider === 'redis') {
      provider = new IoRedisRedisProvider(config);
    } else if (config.provider === 'upstash') {
      provider = new UpstashRedisProvider({ token: config.token, url: config.url });
    } else {
      throw new Error(
        `Unsupported redis provider: ${String((config as { provider?: unknown }).provider)}`,
      );
    }

    await provider.initialize();
    return provider;
  }
}

/** Initialize (or reuse) the shared Redis provider; see `RedisManager.initialize`. */
export const initializeRedis = (config: RedisConfig) => RedisManager.initialize(config);
/** Disconnect and forget the shared Redis provider; see `RedisManager.reset`. */
export const resetRedisClient = () => RedisManager.reset();
/** Whether `config` turns Redis on at all. */
export const isRedisEnabled = (config: RedisConfig) => config.enabled;
export { RedisManager };