@thinkbun/middleware 1.0.1 → 1.0.3

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,21 @@
1
1
  # @thinkbun/middleware
2
2
 
3
+ ## 1.0.3
4
+
5
+ ### Patch Changes
6
+
7
+ - feat(core): 增强 CLI 应用和核心框架功能
8
+ - Updated dependencies
9
+ - @thinkbun/core@1.0.6
10
+
11
+ ## 1.0.2
12
+
13
+ ### Patch Changes
14
+
15
+ - 替换 ajv 为 typebox 进行数据校验
16
+ - Updated dependencies
17
+ - @thinkbun/core@1.0.5
18
+
3
19
  ## 1.0.1
4
20
 
5
21
  ### Patch Changes
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@thinkbun/middleware",
3
- "version": "1.0.1",
3
+ "version": "1.0.3",
4
4
  "module": "src/index.ts",
5
5
  "type": "module",
6
6
  "devDependencies": {
@@ -11,7 +11,7 @@
11
11
  },
12
12
  "dependencies": {
13
13
  "picocolors": "^1.1.1",
14
- "@thinkbun/core": "workspace:*"
14
+ "@thinkbun/core": "1.0.4"
15
15
  },
16
16
  "publishConfig": {
17
17
  "access": "public"
package/src/index.ts CHANGED
@@ -3,15 +3,12 @@ import { cors as _cors, type CorsOptions } from './middlewares/cors';
3
3
  import { errorHandler as _errorHandler, type ErrorHandlerOptions } from './middlewares/errorHandler';
4
4
  import { logger as _logger, type LoggerOptions } from './middlewares/logger';
5
5
  import { staticServe as _staticServe, type StaticOptions } from './middlewares/static';
6
- import { validator as _validator, type ValidatorOptions, type ValidationRule } from './middlewares/validator';
7
6
 
8
7
  // 重新导出各个中间件
9
8
  export { _errorHandler as errorHandler };
10
9
  export type { ErrorHandlerOptions };
11
10
  export { _logger as logger };
12
11
  export type { LoggerOptions };
13
- export { _validator as validator };
14
- export type { ValidatorOptions, ValidationRule };
15
12
  export { _cors as cors };
16
13
  export type { CorsOptions };
17
14
  export { _staticServe as staticServe };
@@ -24,7 +21,6 @@ export type { Middleware, Context } from '@thinkbun/core';
24
21
  export const middlewares = {
25
22
  errorHandler: _errorHandler,
26
23
  logger: _logger,
27
- validator: _validator,
28
24
  cors: _cors,
29
25
  staticServe: _staticServe,
30
26
  };
@@ -38,6 +38,18 @@ export interface CorsOptions {
38
38
  preflightContinue?: boolean;
39
39
  }
40
40
 
41
+ /**
42
+ * CORS中间件函数
43
+ *
44
+ * 处理跨域资源共享,支持:
45
+ * - 自定义允许的源
46
+ * - 自定义HTTP方法
47
+ * - 预检请求处理
48
+ * - 凭证支持
49
+ *
50
+ * @param options CORS配置选项
51
+ * @returns 返回一个中间件函数
52
+ */
41
53
  export function cors(options: CorsOptions = {}): Middleware {
42
54
  const {
43
55
  origin = '*',
@@ -52,7 +64,12 @@ export function cors(options: CorsOptions = {}): Middleware {
52
64
  // 格式化methods为字符串
53
65
  const methodsStr = methods.join(', ').toUpperCase();
54
66
 
55
- // 检查origin是否允许
67
+ /**
68
+ * 检查origin是否允许
69
+ * @private
70
+ * @param originHeader 请求头中的Origin值
71
+ * @returns 允许的origin字符串或false
72
+ */
56
73
  const checkOrigin = (originHeader: string | null): string | false => {
57
74
  if (!originHeader) return false;
58
75
 
@@ -75,6 +92,34 @@ export function cors(options: CorsOptions = {}): Middleware {
75
92
  return false;
76
93
  };
77
94
 
95
+ /**
96
+ * 设置CORS响应头
97
+ * @private
98
+ * @param headers Headers对象
99
+ * @param allowedOrigin 允许的origin
100
+ * @param isPreflight 是否为预检请求
101
+ */
102
+ const setCORSHeaders = (headers: Headers, allowedOrigin: string, isPreflight = false) => {
103
+ headers.set('Access-Control-Allow-Origin', allowedOrigin);
104
+
105
+ if (isPreflight) {
106
+ headers.set('Access-Control-Allow-Methods', methodsStr);
107
+ headers.set(
108
+ 'Access-Control-Allow-Headers',
109
+ allowedHeaders?.join(', ') || headers.get('Access-Control-Request-Headers') || '',
110
+ );
111
+ headers.set('Access-Control-Max-Age', maxAge.toString());
112
+ }
113
+
114
+ if (credentials) {
115
+ headers.set('Access-Control-Allow-Credentials', 'true');
116
+ }
117
+
118
+ if (exposedHeaders && exposedHeaders.length > 0 && !isPreflight) {
119
+ headers.set('Access-Control-Expose-Headers', exposedHeaders.join(', '));
120
+ }
121
+ };
122
+
78
123
  return async (ctx: Context, next: () => Promise<Response | Context | void>): Promise<Response | Context | void> => {
79
124
  // 处理OPTIONS预检请求
80
125
  if (ctx.method === 'OPTIONS') {
@@ -87,16 +132,11 @@ export function cors(options: CorsOptions = {}): Middleware {
87
132
  }
88
133
 
89
134
  // 构建预检响应
90
- const preflightResponse = new Response(null, {
135
+ const preflightHeaders = new Headers();
136
+ setCORSHeaders(preflightHeaders, allowedOrigin, true);
137
+ const preflightResponse = new Response(null, {
91
138
  status: 204,
92
- headers: {
93
- 'Access-Control-Allow-Origin': allowedOrigin,
94
- 'Access-Control-Allow-Methods': methodsStr,
95
- 'Access-Control-Allow-Headers':
96
- allowedHeaders?.join(', ') || ctx.header('Access-Control-Request-Headers') || '',
97
- 'Access-Control-Max-Age': maxAge.toString(),
98
- ...(credentials && { 'Access-Control-Allow-Credentials': 'true' }),
99
- },
139
+ headers: preflightHeaders,
100
140
  });
101
141
 
102
142
  if (!preflightContinue) {
@@ -110,7 +150,7 @@ export function cors(options: CorsOptions = {}): Middleware {
110
150
  if (nextResponse instanceof Response) {
111
151
  const headers = new Headers(nextResponse.headers);
112
152
 
113
- // 复制preflightResponse的所有头到新响应
153
+ // 复制预检响应的所有头到新响应
114
154
  preflightResponse.headers.forEach((value, name) => {
115
155
  headers.set(name, value);
116
156
  });
@@ -143,17 +183,9 @@ export function cors(options: CorsOptions = {}): Middleware {
143
183
  const response = await next();
144
184
 
145
185
  // 设置CORS响应头
146
- const setCorsHeaders = (res: Response) => {
186
+ const setCORSHeadersOnResponse = (res: Response) => {
147
187
  const headers = new Headers(res.headers);
148
- headers.set('Access-Control-Allow-Origin', allowedOrigin);
149
-
150
- if (credentials) {
151
- headers.set('Access-Control-Allow-Credentials', 'true');
152
- }
153
-
154
- if (exposedHeaders && exposedHeaders.length > 0) {
155
- headers.set('Access-Control-Expose-Headers', exposedHeaders.join(', '));
156
- }
188
+ setCORSHeaders(headers, allowedOrigin);
157
189
 
158
190
  return new Response(res.body, {
159
191
  status: res.status,
@@ -163,18 +195,10 @@ export function cors(options: CorsOptions = {}): Middleware {
163
195
  };
164
196
 
165
197
  if (response instanceof Response) {
166
- return setCorsHeaders(response);
198
+ return setCORSHeadersOnResponse(response);
167
199
  } else {
168
200
  // 如果返回的是上下文对象或undefined,设置响应头
169
- ctx.responseHeaders.set('Access-Control-Allow-Origin', allowedOrigin);
170
-
171
- if (credentials) {
172
- ctx.responseHeaders.set('Access-Control-Allow-Credentials', 'true');
173
- }
174
-
175
- if (exposedHeaders && exposedHeaders.length > 0) {
176
- ctx.responseHeaders.set('Access-Control-Expose-Headers', exposedHeaders.join(', '));
177
- }
201
+ setCORSHeaders(ctx.responseHeaders, allowedOrigin);
178
202
  }
179
203
 
180
204
  return response;
@@ -21,6 +21,51 @@ export interface ErrorHandlerOptions {
21
21
  logger?: (error: Error, ctx: Context) => void;
22
22
  }
23
23
 
24
+ /**
25
+ * 错误类型映射表
26
+ */
27
+ const ERROR_TYPE_MAP: Record<string, { status: number; defaultMessage: string }> = {
28
+ BadRequestError: { status: 400, defaultMessage: 'Bad Request' },
29
+ UnauthorizedError: { status: 401, defaultMessage: 'Unauthorized' },
30
+ ForbiddenError: { status: 403, defaultMessage: 'Forbidden' },
31
+ NotFoundError: { status: 404, defaultMessage: 'Not Found' },
32
+ MethodNotAllowedError: { status: 405, defaultMessage: 'Method Not Allowed' },
33
+ ConflictError: { status: 409, defaultMessage: 'Conflict' },
34
+ UnprocessableEntityError: { status: 422, defaultMessage: 'Unprocessable Entity' },
35
+ };
36
+
37
+ /**
38
+ * 根据错误类型获取状态码和消息
39
+ * @private
40
+ * @param error 错误对象
41
+ * @returns 包含状态码和消息的对象
42
+ */
43
+ const getErrorInfo = (error: Error) => {
44
+ const errorType = ERROR_TYPE_MAP[error.name];
45
+ if (errorType) {
46
+ return {
47
+ status: errorType.status,
48
+ message: error.message || errorType.defaultMessage,
49
+ };
50
+ }
51
+ return {
52
+ status: 500,
53
+ message: 'Internal Server Error',
54
+ };
55
+ };
56
+
57
+ /**
58
+ * 错误处理中间件函数
59
+ *
60
+ * 捕获和处理应用中的错误,支持:
61
+ * - 自定义错误处理函数
62
+ * - 错误日志记录
63
+ * - 根据错误类型设置HTTP状态码
64
+ * - 生产环境下的错误信息隐藏
65
+ *
66
+ * @param options 错误处理配置选项
67
+ * @returns 返回一个中间件函数
68
+ */
24
69
  export function errorHandler(options: ErrorHandlerOptions = {}): Middleware {
25
70
  const { exposeStackTrace = process.env.NODE_ENV !== 'production', customHandler, logger = console.error } = options;
26
71
 
@@ -41,32 +86,7 @@ export function errorHandler(options: ErrorHandlerOptions = {}): Middleware {
41
86
 
42
87
  // 默认错误处理
43
88
  const err = error as Error;
44
- let status = 500;
45
- let message = 'Internal Server Error';
46
-
47
- // 根据错误类型设置状态码
48
- if (err.name === 'BadRequestError') {
49
- status = 400;
50
- message = err.message || 'Bad Request';
51
- } else if (err.name === 'UnauthorizedError') {
52
- status = 401;
53
- message = err.message || 'Unauthorized';
54
- } else if (err.name === 'ForbiddenError') {
55
- status = 403;
56
- message = err.message || 'Forbidden';
57
- } else if (err.name === 'NotFoundError') {
58
- status = 404;
59
- message = err.message || 'Not Found';
60
- } else if (err.name === 'MethodNotAllowedError') {
61
- status = 405;
62
- message = err.message || 'Method Not Allowed';
63
- } else if (err.name === 'ConflictError') {
64
- status = 409;
65
- message = err.message || 'Conflict';
66
- } else if (err.name === 'UnprocessableEntityError') {
67
- status = 422;
68
- message = err.message || 'Unprocessable Entity';
69
- }
89
+ const { status, message } = getErrorInfo(err);
70
90
 
71
91
  // 构建错误响应
72
92
  const errorResponse = {
@@ -29,6 +29,18 @@ export interface LoggerOptions {
29
29
  logResponseBody?: boolean;
30
30
  }
31
31
 
32
+ /**
33
+ * 日志中间件函数
34
+ *
35
+ * 记录HTTP请求和响应信息,支持:
36
+ * - 文本和JSON两种日志格式
37
+ * - 自定义日志记录函数
38
+ * - 请求和响应体记录
39
+ * - 路径过滤
40
+ *
41
+ * @param options 日志配置选项
42
+ * @returns 返回一个中间件函数
43
+ */
32
44
  export function logger(options: LoggerOptions = {}): Middleware {
33
45
  const {
34
46
  format = 'text',
@@ -38,21 +50,39 @@ export function logger(options: LoggerOptions = {}): Middleware {
38
50
  logResponseBody = false,
39
51
  } = options;
40
52
 
41
- return async (ctx: Context, next: () => Promise<Response | Context | void>): Promise<Response | Context | void> => {
42
- // 检查是否需要忽略该路径的日志
43
- if (ignorePaths.includes(ctx.path)) {
44
- return await next();
53
+ /**
54
+ * 记录日志的通用方法
55
+ * @private
56
+ * @param message 日志消息
57
+ * @param data 日志数据对象
58
+ */
59
+ const writeLog = (message: string, data?: Record<string, any>) => {
60
+ if (format === 'json') {
61
+ logger(JSON.stringify(data || { message }));
62
+ } else {
63
+ logger(message);
45
64
  }
65
+ };
46
66
 
47
- const startTime = Date.now();
48
- const requestInfo: {
49
- method: string;
50
- path: string;
51
- ip: string;
52
- userAgent: string;
53
- query: Record<string, string>;
54
- body?: unknown;
55
- } = {
67
+ /**
68
+ * 创建请求信息对象
69
+ * @private
70
+ * @param ctx 上下文对象
71
+ * @param includeBody 是否包含请求体
72
+ * @returns 请求信息对象
73
+ */
74
+ const createRequestInfo = async (
75
+ ctx: Context,
76
+ includeBody = false,
77
+ ): Promise<{
78
+ method: string;
79
+ path: string;
80
+ ip: string;
81
+ userAgent: string;
82
+ query: Record<string, string>;
83
+ body?: unknown;
84
+ }> => {
85
+ const requestInfo = {
56
86
  method: ctx.method,
57
87
  path: ctx.path,
58
88
  ip: ctx.ip?.address || 'unknown',
@@ -60,6 +90,28 @@ export function logger(options: LoggerOptions = {}): Middleware {
60
90
  query: Object.fromEntries(ctx.url.searchParams),
61
91
  };
62
92
 
93
+ // 记录请求体(如果需要)
94
+ if (includeBody && ctx.method !== 'GET') {
95
+ try {
96
+ const body = await ctx.body();
97
+ requestInfo.body = body;
98
+ } catch (_error) {
99
+ // 忽略读取请求体错误
100
+ }
101
+ }
102
+
103
+ return requestInfo;
104
+ };
105
+
106
+ return async (ctx: Context, next: () => Promise<Response | Context | void>): Promise<Response | Context | void> => {
107
+ // 检查是否需要忽略该路径的日志
108
+ if (ignorePaths.includes(ctx.path)) {
109
+ return await next();
110
+ }
111
+
112
+ const startTime = Date.now();
113
+ const requestInfo = await createRequestInfo(ctx, logRequestBody);
114
+
63
115
  // 记录请求体(如果需要)
64
116
  if (logRequestBody && ctx.method !== 'GET') {
65
117
  try {
@@ -85,13 +137,10 @@ export function logger(options: LoggerOptions = {}): Middleware {
85
137
  stack: (error as Error).stack,
86
138
  };
87
139
 
88
- if (format === 'json') {
89
- logger(JSON.stringify(errorInfo));
90
- } else {
91
- logger(
92
- `[ERROR] ${requestInfo.method} ${requestInfo.path} ${500} ${endTime - startTime}ms - ${(error as Error).message}`,
93
- );
94
- }
140
+ writeLog(
141
+ `[ERROR] ${requestInfo.method} ${requestInfo.path} 500 ${endTime - startTime}ms - ${(error as Error).message}`,
142
+ errorInfo,
143
+ );
95
144
 
96
145
  throw error;
97
146
  } finally {
@@ -132,11 +181,10 @@ export function logger(options: LoggerOptions = {}): Middleware {
132
181
  }
133
182
  }
134
183
 
135
- if (format === 'json') {
136
- logger(JSON.stringify(responseInfo));
137
- } else {
138
- logger(`[INFO] ${requestInfo.method} ${requestInfo.path} ${response.status} ${endTime - startTime}ms`);
139
- }
184
+ writeLog(
185
+ `[INFO] ${requestInfo.method} ${requestInfo.path} ${response.status} ${endTime - startTime}ms`,
186
+ responseInfo,
187
+ );
140
188
  }
141
189
  }
142
190
  };
@@ -1,14 +1,16 @@
1
1
  import { describe, expect, it } from 'bun:test';
2
- import { cors, errorHandler, logger, staticServe, validator } from '../index';
2
+
3
3
  import type { Context } from '@thinkbun/core';
4
4
 
5
+ import { cors, errorHandler, logger, staticServe, validator } from '../index';
6
+
5
7
  // 创建模拟上下文
6
8
  function createMockContext(req: Request): Context {
7
9
  const mockApp: any = {
8
10
  server: {
9
- requestIP: () => '127.0.0.1'
11
+ requestIP: () => '127.0.0.1',
10
12
  },
11
- logger: console
13
+ logger: console,
12
14
  };
13
15
 
14
16
  const ctx: any = {
@@ -51,13 +53,17 @@ function createMockContext(req: Request): Context {
51
53
  clearCookie: () => ctx,
52
54
  body: async () => ({}),
53
55
  text: async () => '',
54
- json: (data: any, options: any = {}) => new Response(JSON.stringify(data), {
55
- status: options.status || ctx.status,
56
- headers: new Headers([...ctx.responseHeaders.entries(), ...(options.headers ? Object.entries(options.headers) : [])])
57
- }),
56
+ json: (data: any, options: any = {}) =>
57
+ new Response(JSON.stringify(data), {
58
+ status: options.status || ctx.status,
59
+ headers: new Headers([
60
+ ...ctx.responseHeaders.entries(),
61
+ ...(options.headers ? Object.entries(options.headers) : []),
62
+ ]),
63
+ }),
58
64
  ok: () => new Response(null, { status: 200, headers: ctx.responseHeaders }),
59
65
  success: (data: any) => ctx.json({ errno: 0, data }),
60
- fail: (errno: number, msg: string) => ctx.json({ errno, msg })
66
+ fail: (errno: number, msg: string) => ctx.json({ errno, msg }),
61
67
  };
62
68
 
63
69
  return ctx;
@@ -69,89 +75,89 @@ describe('Middleware Tests', () => {
69
75
  it('should add CORS headers to response', async () => {
70
76
  const req = new Request('http://test:5300/', {
71
77
  headers: {
72
- Origin: 'http://example.com'
73
- }
78
+ Origin: 'http://example.com',
79
+ },
74
80
  });
75
81
  const ctx = createMockContext(req);
76
82
  const corsMiddleware = cors();
77
-
83
+
78
84
  await corsMiddleware(ctx, async () => {});
79
-
85
+
80
86
  expect(ctx.responseHeaders.get('Access-Control-Allow-Origin')).toBe('*'); // 默认配置下使用*
81
87
  expect(ctx.responseHeaders.get('Access-Control-Allow-Methods')).toBeNull(); // 非OPTIONS请求不会添加这个头
82
88
  expect(ctx.responseHeaders.get('Access-Control-Allow-Headers')).toBeNull(); // 非OPTIONS请求不会添加这个头
83
89
  expect(ctx.responseHeaders.get('Access-Control-Max-Age')).toBeNull(); // 非OPTIONS请求不会添加这个头
84
90
  });
85
-
91
+
86
92
  it('should handle OPTIONS requests', async () => {
87
93
  const req = new Request('http://test:5300/', {
88
94
  method: 'OPTIONS',
89
95
  headers: {
90
96
  Origin: 'http://example.com',
91
97
  'Access-Control-Request-Methods': 'GET, POST',
92
- 'Access-Control-Request-Headers': 'Content-Type'
93
- }
98
+ 'Access-Control-Request-Headers': 'Content-Type',
99
+ },
94
100
  });
95
101
  const ctx = createMockContext(req);
96
102
  const corsMiddleware = cors();
97
-
103
+
98
104
  const result = await corsMiddleware(ctx, async () => {});
99
-
105
+
100
106
  expect(result).toBeDefined();
101
107
  expect((result as Response).status).toBe(204);
102
108
  });
103
109
  });
104
-
110
+
105
111
  // 测试Logger中间件
106
112
  describe('Logger Middleware', () => {
107
113
  it('should log request and response information', async () => {
108
114
  const req = new Request('http://test:5300/');
109
115
  const ctx = createMockContext(req);
110
-
116
+
111
117
  let logCalled = false;
112
118
  const loggerMiddleware = logger({
113
- logger: () => logCalled = true
119
+ logger: () => (logCalled = true),
114
120
  });
115
-
121
+
116
122
  await loggerMiddleware(ctx, async () => {
117
123
  return new Response('OK', { status: 200 });
118
124
  });
119
-
125
+
120
126
  expect(logCalled).toBe(true);
121
127
  });
122
128
  });
123
-
129
+
124
130
  // 测试ErrorHandler中间件
125
131
  describe('ErrorHandler Middleware', () => {
126
132
  it('should handle errors and return JSON response', async () => {
127
133
  const req = new Request('http://test:5300/');
128
134
  const ctx = createMockContext(req);
129
135
  const errorHandlerMiddleware = errorHandler();
130
-
136
+
131
137
  const result = await errorHandlerMiddleware(ctx, async () => {
132
138
  throw new Error('Test error');
133
139
  });
134
-
140
+
135
141
  expect(result).toBeDefined();
136
142
  const response = result as Response;
137
143
  expect(response.status).toBe(500);
138
144
  const responseData = await response.json();
139
145
  expect(responseData.error).toBe('Internal Server Error');
140
146
  });
141
-
147
+
142
148
  it('should handle HTTP errors with specific status codes', async () => {
143
149
  const req = new Request('http://test:5300/');
144
150
  const ctx = createMockContext(req);
145
151
  const errorHandlerMiddleware = errorHandler();
146
-
152
+
147
153
  // 使用与errorHandler中间件匹配的错误类型
148
154
  const notFoundError = new Error('Not Found');
149
155
  notFoundError.name = 'NotFoundError';
150
-
156
+
151
157
  const result = await errorHandlerMiddleware(ctx, async () => {
152
158
  throw notFoundError;
153
159
  });
154
-
160
+
155
161
  expect(result).toBeDefined();
156
162
  const response = result as Response;
157
163
  expect(response.status).toBe(404);
@@ -159,23 +165,23 @@ describe('Middleware Tests', () => {
159
165
  expect(responseData.error).toBe('Not Found');
160
166
  });
161
167
  });
162
-
168
+
163
169
  // 测试Validator中间件
164
170
  describe('Validator Middleware', () => {
165
171
  it('should validate request body against schema', async () => {
166
172
  const req = new Request('http://test:5300/', {
167
173
  method: 'POST',
168
174
  headers: { 'Content-Type': 'application/json' },
169
- body: JSON.stringify({ name: 'test', age: 25 })
175
+ body: JSON.stringify({ name: 'test', age: 25 }),
170
176
  });
171
177
  const ctx = createMockContext(req);
172
-
178
+
173
179
  // Mock the body method
174
180
  ctx.body = async () => ({ name: 'test', age: 25 });
175
-
181
+
176
182
  // 设置params属性
177
183
  ctx.params = {};
178
-
184
+
179
185
  const validatorMiddleware = validator({
180
186
  rules: [
181
187
  {
@@ -186,37 +192,37 @@ describe('Middleware Tests', () => {
186
192
  type: 'object',
187
193
  properties: {
188
194
  name: { type: 'string' },
189
- age: { type: 'number' }
195
+ age: { type: 'number' },
190
196
  },
191
- required: ['name', 'age']
192
- }
193
- }
194
- }
195
- ]
197
+ required: ['name', 'age'],
198
+ },
199
+ },
200
+ },
201
+ ],
196
202
  });
197
-
203
+
198
204
  let nextCalled = false;
199
205
  await validatorMiddleware(ctx, async () => {
200
206
  nextCalled = true;
201
207
  });
202
-
208
+
203
209
  expect(nextCalled).toBe(true);
204
210
  });
205
-
211
+
206
212
  it('should throw error for invalid request body', async () => {
207
213
  const req = new Request('http://test:5300/', {
208
214
  method: 'POST',
209
215
  headers: { 'Content-Type': 'application/json' },
210
- body: JSON.stringify({ name: 'test' })
216
+ body: JSON.stringify({ name: 'test' }),
211
217
  });
212
218
  const ctx = createMockContext(req);
213
-
219
+
214
220
  // Mock the body method
215
221
  ctx.body = async () => ({ name: 'test' });
216
-
222
+
217
223
  // 设置params属性
218
224
  ctx.params = {};
219
-
225
+
220
226
  const validatorMiddleware = validator({
221
227
  rules: [
222
228
  {
@@ -227,28 +233,28 @@ describe('Middleware Tests', () => {
227
233
  type: 'object',
228
234
  properties: {
229
235
  name: { type: 'string' },
230
- age: { type: 'number' }
236
+ age: { type: 'number' },
231
237
  },
232
- required: ['name', 'age']
233
- }
234
- }
235
- }
236
- ]
238
+ required: ['name', 'age'],
239
+ },
240
+ },
241
+ },
242
+ ],
237
243
  });
238
-
244
+
239
245
  await expect(validatorMiddleware(ctx, async () => {})).rejects.toThrow();
240
246
  });
241
247
  });
242
-
248
+
243
249
  // 测试StaticServe中间件
244
250
  describe('StaticServe Middleware', () => {
245
251
  it('should serve static files', async () => {
246
252
  const req = new Request('http://test:5300/test.txt');
247
253
  const ctx = createMockContext(req);
248
-
254
+
249
255
  // 这个测试会实际查找文件,所以我们只测试中间件是否能正常执行
250
256
  const staticMiddleware = staticServe({ root: './public' });
251
-
257
+
252
258
  let nextCalled = false;
253
259
  try {
254
260
  await staticMiddleware(ctx, async () => {
@@ -257,7 +263,7 @@ describe('Middleware Tests', () => {
257
263
  } catch (error) {
258
264
  // 忽略文件不存在的错误
259
265
  }
260
-
266
+
261
267
  // 因为文件可能不存在,所以无论如何都应该调用next
262
268
  expect(nextCalled).toBe(true);
263
269
  });
package/tsconfig.json CHANGED
@@ -1,7 +1,4 @@
1
1
  {
2
2
  "extends": "../../tsconfig.json",
3
- "include": [
4
- "src/**/*",
5
- "package.json"
6
- ]
3
+ "include": ["src/**/*", "package.json"],
7
4
  }
@@ -1,187 +0,0 @@
1
- import { BadRequestError, type Context, type Middleware, type RouteSchema } from '@thinkbun/core';
2
- import Ajv from 'ajv';
3
- import addFormats from 'ajv-formats';
4
-
5
- // 使用与核心框架相同的Ajv配置
6
- const ajv = addFormats(
7
- new Ajv({
8
- coerceTypes: true, // 自动类型转换
9
- useDefaults: true, // 使用默认值
10
- removeAdditional: false, // 保留额外属性
11
- allErrors: true, // 返回所有错误
12
- }),
13
- [
14
- 'date-time',
15
- 'time',
16
- 'date',
17
- 'email',
18
- 'hostname',
19
- 'ipv4',
20
- 'ipv6',
21
- 'uri',
22
- 'uri-reference',
23
- 'uuid',
24
- 'uri-template',
25
- 'json-pointer',
26
- 'relative-json-pointer',
27
- 'regex',
28
- ],
29
- );
30
-
31
- export interface ValidationRule {
32
- /**
33
- * 路由路径
34
- * @example '/api/users'
35
- */
36
- path: string;
37
- /**
38
- * HTTP方法
39
- * @example 'POST'
40
- */
41
- method: string;
42
- /**
43
- * 验证模式
44
- */
45
- schema: RouteSchema;
46
- }
47
-
48
- export interface ValidatorOptions {
49
- /**
50
- * 验证规则列表
51
- */
52
- rules: ValidationRule[];
53
- /**
54
- * 自定义错误处理函数
55
- * @param error 验证错误
56
- * @param ctx 请求上下文
57
- * @returns 返回自定义的Response对象或抛出错误使用默认处理
58
- */
59
- errorHandler?: (error: Error, ctx: Context) => Response | never;
60
- }
61
-
62
- export function validator(options: ValidatorOptions): Middleware {
63
- if (!options.rules || options.rules.length === 0) {
64
- throw new Error('Validator middleware requires at least one validation rule');
65
- }
66
-
67
- // 编译所有验证规则
68
- const compiledRules = options.rules.map((rule) => {
69
- const compiled: any = {};
70
-
71
- if (rule.schema.query) {
72
- compiled.query = ajv.compile(rule.schema.query);
73
- }
74
-
75
- if (rule.schema.body) {
76
- compiled.body = ajv.compile(rule.schema.body);
77
- }
78
-
79
- if (rule.schema.params) {
80
- compiled.params = ajv.compile(rule.schema.params);
81
- }
82
-
83
- if (rule.schema.headers) {
84
- compiled.headers = ajv.compile(rule.schema.headers);
85
- }
86
-
87
- return {
88
- ...rule,
89
- compiled,
90
- };
91
- });
92
-
93
- return async (ctx: Context, next: () => Promise<Response | Context | void>): Promise<Response | Context | void> => {
94
- // 查找匹配的验证规则
95
- const matchingRules = compiledRules.filter(
96
- (rule) => rule.path === ctx.path && rule.method.toUpperCase() === ctx.method.toUpperCase(),
97
- );
98
-
99
- if (matchingRules.length === 0) {
100
- // 没有匹配的验证规则,继续执行后续中间件
101
- return await next();
102
- }
103
-
104
- // 使用第一个匹配的规则
105
- const rule = matchingRules[0]!; // 类型断言,因为已经确保matchingRules.length > 0
106
- const errors: any[] = [];
107
-
108
- // 验证查询参数
109
- if (rule.compiled.query) {
110
- const query = Object.fromEntries(ctx.url.searchParams);
111
- if (!rule.compiled.query(query)) {
112
- errors.push(
113
- ...rule.compiled.query.errors!.map((err: any) => ({
114
- ...err,
115
- source: 'query',
116
- })),
117
- );
118
- }
119
- }
120
-
121
- // 验证路由参数
122
- if (rule.compiled.params) {
123
- if (!rule.compiled.params(ctx.params)) {
124
- errors.push(
125
- ...rule.compiled.params.errors!.map((err: any) => ({
126
- ...err,
127
- source: 'params',
128
- })),
129
- );
130
- }
131
- }
132
-
133
- // 验证请求头
134
- if (rule.compiled.headers) {
135
- const headers = Object.fromEntries(ctx.headers);
136
- if (!rule.compiled.headers(headers)) {
137
- errors.push(
138
- ...rule.compiled.headers.errors!.map((err: any) => ({
139
- ...err,
140
- source: 'headers',
141
- })),
142
- );
143
- }
144
- }
145
-
146
- // 验证请求体
147
- if (rule.compiled.body && ctx.method !== 'GET' && ctx.method !== 'HEAD') {
148
- try {
149
- const body = await ctx.body();
150
- if (!rule.compiled.body(body)) {
151
- errors.push(
152
- ...rule.compiled.body.errors!.map((err: any) => ({
153
- ...err,
154
- source: 'body',
155
- })),
156
- );
157
- }
158
- } catch (error) {
159
- // 如果无法解析请求体,添加错误信息
160
- errors.push({
161
- keyword: 'invalid_body',
162
- message: 'Invalid request body',
163
- source: 'body',
164
- });
165
- }
166
- }
167
-
168
- // 如果有验证错误
169
- if (errors.length > 0) {
170
- const error = new BadRequestError(errors);
171
-
172
- // 使用自定义错误处理(如果提供)
173
- if (options.errorHandler) {
174
- const response = options.errorHandler(error, ctx);
175
- if (response) {
176
- return response;
177
- }
178
- }
179
-
180
- // 抛出错误,由错误处理中间件处理
181
- throw error;
182
- }
183
-
184
- // 验证通过,继续执行后续中间件
185
- return await next();
186
- };
187
- }