spikard 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +659 -659
  4. data/ext/spikard_rb/Cargo.toml +17 -17
  5. data/ext/spikard_rb/extconf.rb +10 -10
  6. data/ext/spikard_rb/src/lib.rs +6 -6
  7. data/lib/spikard/app.rb +386 -386
  8. data/lib/spikard/background.rb +27 -27
  9. data/lib/spikard/config.rb +396 -396
  10. data/lib/spikard/converters.rb +13 -13
  11. data/lib/spikard/handler_wrapper.rb +113 -113
  12. data/lib/spikard/provide.rb +214 -214
  13. data/lib/spikard/response.rb +173 -173
  14. data/lib/spikard/schema.rb +243 -243
  15. data/lib/spikard/sse.rb +111 -111
  16. data/lib/spikard/streaming_response.rb +44 -44
  17. data/lib/spikard/testing.rb +221 -221
  18. data/lib/spikard/upload_file.rb +131 -131
  19. data/lib/spikard/version.rb +5 -5
  20. data/lib/spikard/websocket.rb +59 -59
  21. data/lib/spikard.rb +43 -43
  22. data/sig/spikard.rbs +360 -360
  23. data/vendor/crates/spikard-core/Cargo.toml +40 -40
  24. data/vendor/crates/spikard-core/src/bindings/mod.rs +3 -3
  25. data/vendor/crates/spikard-core/src/bindings/response.rs +133 -133
  26. data/vendor/crates/spikard-core/src/debug.rs +63 -63
  27. data/vendor/crates/spikard-core/src/di/container.rs +726 -726
  28. data/vendor/crates/spikard-core/src/di/dependency.rs +273 -273
  29. data/vendor/crates/spikard-core/src/di/error.rs +118 -118
  30. data/vendor/crates/spikard-core/src/di/factory.rs +538 -538
  31. data/vendor/crates/spikard-core/src/di/graph.rs +545 -545
  32. data/vendor/crates/spikard-core/src/di/mod.rs +192 -192
  33. data/vendor/crates/spikard-core/src/di/resolved.rs +411 -411
  34. data/vendor/crates/spikard-core/src/di/value.rs +283 -283
  35. data/vendor/crates/spikard-core/src/errors.rs +39 -39
  36. data/vendor/crates/spikard-core/src/http.rs +153 -153
  37. data/vendor/crates/spikard-core/src/lib.rs +29 -29
  38. data/vendor/crates/spikard-core/src/lifecycle.rs +422 -422
  39. data/vendor/crates/spikard-core/src/parameters.rs +722 -722
  40. data/vendor/crates/spikard-core/src/problem.rs +310 -310
  41. data/vendor/crates/spikard-core/src/request_data.rs +189 -189
  42. data/vendor/crates/spikard-core/src/router.rs +249 -249
  43. data/vendor/crates/spikard-core/src/schema_registry.rs +183 -183
  44. data/vendor/crates/spikard-core/src/type_hints.rs +304 -304
  45. data/vendor/crates/spikard-core/src/validation.rs +699 -699
  46. data/vendor/crates/spikard-http/Cargo.toml +58 -58
  47. data/vendor/crates/spikard-http/src/auth.rs +247 -247
  48. data/vendor/crates/spikard-http/src/background.rs +249 -249
  49. data/vendor/crates/spikard-http/src/bindings/mod.rs +3 -3
  50. data/vendor/crates/spikard-http/src/bindings/response.rs +1 -1
  51. data/vendor/crates/spikard-http/src/body_metadata.rs +8 -8
  52. data/vendor/crates/spikard-http/src/cors.rs +490 -490
  53. data/vendor/crates/spikard-http/src/debug.rs +63 -63
  54. data/vendor/crates/spikard-http/src/di_handler.rs +423 -423
  55. data/vendor/crates/spikard-http/src/handler_response.rs +190 -190
  56. data/vendor/crates/spikard-http/src/handler_trait.rs +228 -228
  57. data/vendor/crates/spikard-http/src/handler_trait_tests.rs +284 -284
  58. data/vendor/crates/spikard-http/src/lib.rs +529 -529
  59. data/vendor/crates/spikard-http/src/lifecycle/adapter.rs +149 -149
  60. data/vendor/crates/spikard-http/src/lifecycle.rs +428 -428
  61. data/vendor/crates/spikard-http/src/middleware/mod.rs +285 -285
  62. data/vendor/crates/spikard-http/src/middleware/multipart.rs +86 -86
  63. data/vendor/crates/spikard-http/src/middleware/urlencoded.rs +147 -147
  64. data/vendor/crates/spikard-http/src/middleware/validation.rs +287 -287
  65. data/vendor/crates/spikard-http/src/openapi/mod.rs +309 -309
  66. data/vendor/crates/spikard-http/src/openapi/parameter_extraction.rs +190 -190
  67. data/vendor/crates/spikard-http/src/openapi/schema_conversion.rs +308 -308
  68. data/vendor/crates/spikard-http/src/openapi/spec_generation.rs +195 -195
  69. data/vendor/crates/spikard-http/src/parameters.rs +1 -1
  70. data/vendor/crates/spikard-http/src/problem.rs +1 -1
  71. data/vendor/crates/spikard-http/src/query_parser.rs +369 -369
  72. data/vendor/crates/spikard-http/src/response.rs +399 -399
  73. data/vendor/crates/spikard-http/src/router.rs +1 -1
  74. data/vendor/crates/spikard-http/src/schema_registry.rs +1 -1
  75. data/vendor/crates/spikard-http/src/server/handler.rs +87 -87
  76. data/vendor/crates/spikard-http/src/server/lifecycle_execution.rs +98 -98
  77. data/vendor/crates/spikard-http/src/server/mod.rs +805 -805
  78. data/vendor/crates/spikard-http/src/server/request_extraction.rs +119 -119
  79. data/vendor/crates/spikard-http/src/sse.rs +447 -447
  80. data/vendor/crates/spikard-http/src/testing/form.rs +14 -14
  81. data/vendor/crates/spikard-http/src/testing/multipart.rs +60 -60
  82. data/vendor/crates/spikard-http/src/testing/test_client.rs +285 -285
  83. data/vendor/crates/spikard-http/src/testing.rs +377 -377
  84. data/vendor/crates/spikard-http/src/type_hints.rs +1 -1
  85. data/vendor/crates/spikard-http/src/validation.rs +1 -1
  86. data/vendor/crates/spikard-http/src/websocket.rs +324 -324
  87. data/vendor/crates/spikard-rb/Cargo.toml +42 -42
  88. data/vendor/crates/spikard-rb/build.rs +8 -8
  89. data/vendor/crates/spikard-rb/src/background.rs +63 -63
  90. data/vendor/crates/spikard-rb/src/config.rs +294 -294
  91. data/vendor/crates/spikard-rb/src/conversion.rs +453 -453
  92. data/vendor/crates/spikard-rb/src/di.rs +409 -409
  93. data/vendor/crates/spikard-rb/src/handler.rs +625 -625
  94. data/vendor/crates/spikard-rb/src/lib.rs +2771 -2771
  95. data/vendor/crates/spikard-rb/src/lifecycle.rs +274 -274
  96. data/vendor/crates/spikard-rb/src/server.rs +283 -283
  97. data/vendor/crates/spikard-rb/src/sse.rs +231 -231
  98. data/vendor/crates/spikard-rb/src/test_client.rs +404 -404
  99. data/vendor/crates/spikard-rb/src/test_sse.rs +143 -143
  100. data/vendor/crates/spikard-rb/src/test_websocket.rs +221 -221
  101. data/vendor/crates/spikard-rb/src/websocket.rs +233 -233
  102. data/vendor/spikard-core/Cargo.toml +40 -40
  103. data/vendor/spikard-core/src/bindings/mod.rs +3 -3
  104. data/vendor/spikard-core/src/bindings/response.rs +133 -133
  105. data/vendor/spikard-core/src/debug.rs +63 -63
  106. data/vendor/spikard-core/src/di/container.rs +726 -726
  107. data/vendor/spikard-core/src/di/dependency.rs +273 -273
  108. data/vendor/spikard-core/src/di/error.rs +118 -118
  109. data/vendor/spikard-core/src/di/factory.rs +538 -538
  110. data/vendor/spikard-core/src/di/graph.rs +545 -545
  111. data/vendor/spikard-core/src/di/mod.rs +192 -192
  112. data/vendor/spikard-core/src/di/resolved.rs +411 -411
  113. data/vendor/spikard-core/src/di/value.rs +283 -283
  114. data/vendor/spikard-core/src/http.rs +153 -153
  115. data/vendor/spikard-core/src/lib.rs +28 -28
  116. data/vendor/spikard-core/src/lifecycle.rs +422 -422
  117. data/vendor/spikard-core/src/parameters.rs +719 -719
  118. data/vendor/spikard-core/src/problem.rs +310 -310
  119. data/vendor/spikard-core/src/request_data.rs +189 -189
  120. data/vendor/spikard-core/src/router.rs +249 -249
  121. data/vendor/spikard-core/src/schema_registry.rs +183 -183
  122. data/vendor/spikard-core/src/type_hints.rs +304 -304
  123. data/vendor/spikard-core/src/validation.rs +699 -699
  124. data/vendor/spikard-http/Cargo.toml +58 -58
  125. data/vendor/spikard-http/src/auth.rs +247 -247
  126. data/vendor/spikard-http/src/background.rs +249 -249
  127. data/vendor/spikard-http/src/bindings/mod.rs +3 -3
  128. data/vendor/spikard-http/src/bindings/response.rs +1 -1
  129. data/vendor/spikard-http/src/body_metadata.rs +8 -8
  130. data/vendor/spikard-http/src/cors.rs +490 -490
  131. data/vendor/spikard-http/src/debug.rs +63 -63
  132. data/vendor/spikard-http/src/di_handler.rs +423 -423
  133. data/vendor/spikard-http/src/handler_response.rs +190 -190
  134. data/vendor/spikard-http/src/handler_trait.rs +228 -228
  135. data/vendor/spikard-http/src/handler_trait_tests.rs +284 -284
  136. data/vendor/spikard-http/src/lib.rs +529 -529
  137. data/vendor/spikard-http/src/lifecycle/adapter.rs +149 -149
  138. data/vendor/spikard-http/src/lifecycle.rs +428 -428
  139. data/vendor/spikard-http/src/middleware/mod.rs +285 -285
  140. data/vendor/spikard-http/src/middleware/multipart.rs +86 -86
  141. data/vendor/spikard-http/src/middleware/urlencoded.rs +147 -147
  142. data/vendor/spikard-http/src/middleware/validation.rs +287 -287
  143. data/vendor/spikard-http/src/openapi/mod.rs +309 -309
  144. data/vendor/spikard-http/src/openapi/parameter_extraction.rs +190 -190
  145. data/vendor/spikard-http/src/openapi/schema_conversion.rs +308 -308
  146. data/vendor/spikard-http/src/openapi/spec_generation.rs +195 -195
  147. data/vendor/spikard-http/src/parameters.rs +1 -1
  148. data/vendor/spikard-http/src/problem.rs +1 -1
  149. data/vendor/spikard-http/src/query_parser.rs +369 -369
  150. data/vendor/spikard-http/src/response.rs +399 -399
  151. data/vendor/spikard-http/src/router.rs +1 -1
  152. data/vendor/spikard-http/src/schema_registry.rs +1 -1
  153. data/vendor/spikard-http/src/server/handler.rs +80 -80
  154. data/vendor/spikard-http/src/server/lifecycle_execution.rs +98 -98
  155. data/vendor/spikard-http/src/server/mod.rs +805 -805
  156. data/vendor/spikard-http/src/server/request_extraction.rs +119 -119
  157. data/vendor/spikard-http/src/sse.rs +447 -447
  158. data/vendor/spikard-http/src/testing/form.rs +14 -14
  159. data/vendor/spikard-http/src/testing/multipart.rs +60 -60
  160. data/vendor/spikard-http/src/testing/test_client.rs +285 -285
  161. data/vendor/spikard-http/src/testing.rs +377 -377
  162. data/vendor/spikard-http/src/type_hints.rs +1 -1
  163. data/vendor/spikard-http/src/validation.rs +1 -1
  164. data/vendor/spikard-http/src/websocket.rs +324 -324
  165. data/vendor/spikard-rb/Cargo.toml +42 -42
  166. data/vendor/spikard-rb/build.rs +8 -8
  167. data/vendor/spikard-rb/src/background.rs +63 -63
  168. data/vendor/spikard-rb/src/config.rs +294 -294
  169. data/vendor/spikard-rb/src/conversion.rs +392 -392
  170. data/vendor/spikard-rb/src/di.rs +409 -409
  171. data/vendor/spikard-rb/src/handler.rs +534 -534
  172. data/vendor/spikard-rb/src/lib.rs +2020 -2020
  173. data/vendor/spikard-rb/src/lifecycle.rs +267 -267
  174. data/vendor/spikard-rb/src/server.rs +283 -283
  175. data/vendor/spikard-rb/src/sse.rs +231 -231
  176. data/vendor/spikard-rb/src/test_client.rs +404 -404
  177. data/vendor/spikard-rb/src/test_sse.rs +143 -143
  178. data/vendor/spikard-rb/src/test_websocket.rs +221 -221
  179. data/vendor/spikard-rb/src/websocket.rs +233 -233
  180. metadata +1 -1
@@ -1,490 +1,490 @@
1
- //! CORS (Cross-Origin Resource Sharing) handling
2
- //!
3
- //! Handles CORS preflight requests and adds CORS headers to responses
4
-
5
- use crate::CorsConfig;
6
- use axum::body::Body;
7
- use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
8
- use axum::response::IntoResponse;
9
-
10
/// Decide whether a request `Origin` is permitted by the configured allow-list.
///
/// An entry of `"*"` admits every origin; any other entry must equal the
/// origin exactly (case-sensitive comparison). An empty origin is always
/// rejected, even when a wildcard is configured.
///
/// # Arguments
/// * `origin` - Value of the request's `Origin` header (e.g. `https://example.com`)
/// * `allowed_origins` - Origins configured as allowed for CORS
///
/// # Returns
/// `true` when some configured entry matches, otherwise `false`.
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "*" || candidate == origin)
}
30
-
31
/// Decide whether an HTTP method is permitted by the CORS configuration.
///
/// An entry of `"*"` admits every method; other entries are compared with
/// `method` ignoring ASCII case, so `"get"` matches a configured `"GET"`.
///
/// # Arguments
/// * `method` - Requested method name (e.g. `"GET"`, `"POST"`)
/// * `allowed_methods` - Methods configured as allowed for CORS
///
/// # Returns
/// `true` as soon as one entry matches; `false` if none does (including when
/// the configured list is empty).
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    for entry in allowed_methods {
        if entry == "*" || entry.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
47
-
48
/// Decide whether every requested header name is permitted.
///
/// If the configuration contains `"*"`, all names are accepted. Otherwise each
/// requested name must equal some configured name, ignoring ASCII case. An
/// empty `requested` slice is vacuously accepted.
///
/// # Arguments
/// * `requested` - Header names the client asked to send
/// * `allowed` - Header names configured as allowed for CORS
///
/// # Returns
/// `true` if every requested name is allowed, `false` if at least one is not.
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    let has_wildcard = allowed.iter().any(|entry| entry == "*");

    has_wildcard
        || requested
            .iter()
            .all(|name| allowed.iter().any(|entry| entry.eq_ignore_ascii_case(name)))
}
70
-
71
- /// Handle CORS preflight (OPTIONS) request
72
- ///
73
- /// Validates the request against the CORS configuration and returns appropriate
74
- /// response or error. This function processes OPTIONS requests as defined in the
75
- /// CORS specification (RFC 7231).
76
- ///
77
- /// # Validation
78
- ///
79
- /// Checks the following conditions:
80
- /// 1. **Origin Header:** Must be present and match configured allowed origins
81
- /// 2. **Access-Control-Request-Method:** Must match configured allowed methods
82
- /// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
83
- ///
84
- /// # Success Response
85
- ///
86
- /// Returns HTTP 204 (No Content) with the following response headers:
87
- /// - `Access-Control-Allow-Origin` - The origin that is allowed
88
- /// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
89
- /// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
90
- /// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
91
- /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
92
- ///
93
- /// # Error Response
94
- ///
95
- /// Returns HTTP 403 (Forbidden) if validation fails for:
96
- /// - Origin not in allowed list
97
- /// - Requested method not allowed
98
- /// - Requested headers not allowed
99
- ///
100
- /// # Arguments
101
- /// * `headers` - Request headers containing CORS preflight information
102
- /// * `cors_config` - CORS configuration to validate against
103
- ///
104
- /// # Returns
105
- /// * `Ok(Response)` - 204 No Content with CORS headers
106
- /// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
107
- pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
108
- let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
109
-
110
- if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
111
- return Err(Box::new(
112
- (
113
- StatusCode::FORBIDDEN,
114
- axum::Json(serde_json::json!({
115
- "detail": format!("CORS request from origin '{}' not allowed", origin)
116
- })),
117
- )
118
- .into_response(),
119
- ));
120
- }
121
-
122
- let requested_method = headers
123
- .get("access-control-request-method")
124
- .and_then(|v| v.to_str().ok())
125
- .unwrap_or("");
126
-
127
- if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
128
- return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
129
- }
130
-
131
- let requested_headers_str = headers
132
- .get("access-control-request-headers")
133
- .and_then(|v| v.to_str().ok());
134
-
135
- if let Some(req_headers) = requested_headers_str {
136
- let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
137
-
138
- if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
139
- return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
140
- }
141
- }
142
-
143
- let mut response = Response::builder().status(StatusCode::NO_CONTENT);
144
-
145
- let headers_mut = match response.headers_mut() {
146
- Some(headers) => headers,
147
- None => {
148
- return Err(Box::new(
149
- (
150
- StatusCode::INTERNAL_SERVER_ERROR,
151
- axum::Json(serde_json::json!({
152
- "detail": "Failed to construct response headers"
153
- })),
154
- )
155
- .into_response(),
156
- ));
157
- }
158
- };
159
-
160
- headers_mut.insert(
161
- "access-control-allow-origin",
162
- HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
163
- );
164
-
165
- let methods = cors_config.allowed_methods.join(", ");
166
- headers_mut.insert(
167
- "access-control-allow-methods",
168
- HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
169
- );
170
-
171
- let allowed_headers = cors_config.allowed_headers.join(", ");
172
- headers_mut.insert(
173
- "access-control-allow-headers",
174
- HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
175
- );
176
-
177
- if let Some(max_age) = cors_config.max_age
178
- && let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
179
- {
180
- headers_mut.insert("access-control-max-age", header_val);
181
- }
182
-
183
- if let Some(true) = cors_config.allow_credentials {
184
- headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
185
- }
186
-
187
- match response.body(Body::empty()) {
188
- Ok(resp) => Ok(resp),
189
- Err(_) => Err(Box::new(
190
- (
191
- StatusCode::INTERNAL_SERVER_ERROR,
192
- axum::Json(serde_json::json!({
193
- "detail": "Failed to construct response body"
194
- })),
195
- )
196
- .into_response(),
197
- )),
198
- }
199
- }
200
-
201
- /// Add CORS headers to a successful response
202
- ///
203
- /// Adds appropriate CORS headers to the response based on the configuration.
204
- /// This function should be called for successful (non-error) responses to
205
- /// cross-origin requests.
206
- ///
207
- /// # Headers Added
208
- ///
209
- /// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
210
- /// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
211
- /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
212
- ///
213
- /// # Arguments
214
- /// * `response` - Mutable reference to the response to modify
215
- /// * `origin` - The origin from the request (e.g., `<https://example.com>`)
216
- /// * `cors_config` - CORS configuration to apply
217
- ///
218
- /// # Example
219
- ///
220
- /// ```ignore
221
- /// let mut response = Response::new(Body::empty());
222
- /// add_cors_headers(&mut response, "https://example.com", &cors_config);
223
- /// ```
224
- pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
225
- let headers = response.headers_mut();
226
-
227
- if let Ok(origin_value) = HeaderValue::from_str(origin) {
228
- headers.insert("access-control-allow-origin", origin_value);
229
- }
230
-
231
- if let Some(ref expose_headers) = cors_config.expose_headers {
232
- let expose = expose_headers.join(", ");
233
- if let Ok(expose_value) = HeaderValue::from_str(&expose) {
234
- headers.insert("access-control-expose-headers", expose_value);
235
- }
236
- }
237
-
238
- if let Some(true) = cors_config.allow_credentials {
239
- headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
240
- }
241
- }
242
-
243
- /// Validate a non-preflight CORS request
244
- ///
245
- /// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
246
- /// Returns an error response if validation fails.
247
- ///
248
- /// # Validation
249
- ///
250
- /// - If no Origin header is present, the request is allowed (not a CORS request)
251
- /// - If Origin header is present, it must match the allowed origins
252
- ///
253
- /// # Arguments
254
- /// * `headers` - Request headers containing origin information
255
- /// * `cors_config` - CORS configuration to validate against
256
- ///
257
- /// # Returns
258
- /// * `Ok(())` - Request is allowed
259
- /// * `Err(Response)` - 403 Forbidden with error details
260
- ///
261
- /// # Note
262
- ///
263
- /// This function is for actual requests, not OPTIONS preflight requests.
264
- /// Use `handle_preflight` for OPTIONS requests.
265
- pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
266
- let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
267
-
268
- if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
269
- return Err(Box::new(
270
- (
271
- StatusCode::FORBIDDEN,
272
- axum::Json(serde_json::json!({
273
- "detail": format!("CORS request from origin '{}' not allowed", origin)
274
- })),
275
- )
276
- .into_response(),
277
- ));
278
- }
279
- Ok(())
280
- }
281
-
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by most tests: one allowed origin, GET/POST, two
    /// allowed request headers, one exposed header, a 3600-second max-age, and
    /// credentials enabled.
    fn make_cors_config() -> CorsConfig {
        CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
            expose_headers: Some(vec!["x-custom-header".to_string()]),
            max_age: Some(3600),
            allow_credentials: Some(true),
        }
    }

    /// Turn string literals into the owned list type the helpers expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| String::from(item)).collect()
    }

    /// Build a request `HeaderMap` from static `(name, value)` pairs.
    fn request_headers(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Assert that `result` is a rejection and return the rejection's status.
    fn rejection_status<T>(result: Result<T, Box<Response<Body>>>) -> StatusCode {
        match result {
            Ok(_) => panic!("expected the request to be rejected"),
            Err(response) => response.status(),
        }
    }

    #[test]
    fn test_is_origin_allowed_exact_match() {
        let allowed = owned(&["https://example.com"]);
        assert!(is_origin_allowed("https://example.com", &allowed));
        assert!(!is_origin_allowed("https://evil.com", &allowed));
    }

    #[test]
    fn test_is_origin_allowed_wildcard() {
        let allowed = owned(&["*"]);
        for origin in ["https://example.com", "https://any-domain.com"] {
            assert!(is_origin_allowed(origin, &allowed), "{origin} should match the wildcard");
        }
    }

    #[test]
    fn test_is_origin_allowed_empty_origin() {
        assert!(!is_origin_allowed("", &owned(&["*"])));
    }

    #[test]
    fn test_is_method_allowed_case_insensitive() {
        let allowed = owned(&["GET", "POST"]);
        for method in ["GET", "get", "POST", "post"] {
            assert!(is_method_allowed(method, &allowed), "{method} should be allowed");
        }
        assert!(!is_method_allowed("DELETE", &allowed));
    }

    #[test]
    fn test_is_method_allowed_wildcard() {
        let allowed = owned(&["*"]);
        for method in ["GET", "DELETE", "PATCH"] {
            assert!(is_method_allowed(method, &allowed), "{method} should match the wildcard");
        }
    }

    #[test]
    fn test_are_headers_allowed_case_insensitive() {
        let allowed = owned(&["Content-Type", "Authorization"]);
        assert!(are_headers_allowed(&["content-type"], &allowed));
        assert!(are_headers_allowed(&["AUTHORIZATION"], &allowed));
        assert!(are_headers_allowed(&["content-type", "authorization"], &allowed));
        assert!(!are_headers_allowed(&["x-custom"], &allowed));
    }

    #[test]
    fn test_are_headers_allowed_wildcard() {
        let allowed = owned(&["*"]);
        assert!(are_headers_allowed(&["any-header"], &allowed));
        assert!(are_headers_allowed(&["multiple", "headers"], &allowed));
    }

    #[test]
    fn test_handle_preflight_success() {
        let request = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "POST"),
            ("access-control-request-headers", "content-type"),
        ]);

        let Ok(response) = handle_preflight(&request, &make_cors_config()) else {
            panic!("preflight from an allowed origin should succeed");
        };
        assert_eq!(response.status(), StatusCode::NO_CONTENT);

        let got = response.headers();
        assert_eq!(
            got.get("access-control-allow-origin").expect("allow-origin header"),
            "https://example.com"
        );
        let methods = got
            .get("access-control-allow-methods")
            .and_then(|value| value.to_str().ok())
            .expect("allow-methods header");
        assert!(methods.contains("POST"));
        assert_eq!(got.get("access-control-max-age").expect("max-age header"), "3600");
        assert_eq!(
            got.get("access-control-allow-credentials").expect("allow-credentials header"),
            "true"
        );
    }

    #[test]
    fn test_handle_preflight_origin_not_allowed() {
        let request = request_headers(&[
            ("origin", "https://evil.com"),
            ("access-control-request-method", "GET"),
        ]);
        let status = rejection_status(handle_preflight(&request, &make_cors_config()));
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_handle_preflight_method_not_allowed() {
        let request = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "DELETE"),
        ]);
        let status = rejection_status(handle_preflight(&request, &make_cors_config()));
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_handle_preflight_header_not_allowed() {
        let request = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "POST"),
            ("access-control-request-headers", "x-forbidden-header"),
        ]);
        let status = rejection_status(handle_preflight(&request, &make_cors_config()));
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_handle_preflight_empty_origin() {
        let request = request_headers(&[]);
        let status = rejection_status(handle_preflight(&request, &make_cors_config()));
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_add_cors_headers() {
        let mut response = Response::new(Body::empty());
        add_cors_headers(&mut response, "https://example.com", &make_cors_config());

        let got = response.headers();
        assert_eq!(
            got.get("access-control-allow-origin").expect("allow-origin header"),
            "https://example.com"
        );
        assert_eq!(
            got.get("access-control-expose-headers").expect("expose-headers header"),
            "x-custom-header"
        );
        assert_eq!(
            got.get("access-control-allow-credentials").expect("allow-credentials header"),
            "true"
        );
    }

    #[test]
    fn test_validate_cors_request_allowed() {
        let request = request_headers(&[("origin", "https://example.com")]);
        assert!(validate_cors_request(&request, &make_cors_config()).is_ok());
    }

    #[test]
    fn test_validate_cors_request_not_allowed() {
        let request = request_headers(&[("origin", "https://evil.com")]);
        let status = rejection_status(validate_cors_request(&request, &make_cors_config()));
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_validate_cors_request_no_origin() {
        let request = request_headers(&[]);
        assert!(validate_cors_request(&request, &make_cors_config()).is_ok());
    }
}
1
+ //! CORS (Cross-Origin Resource Sharing) handling
2
+ //!
3
+ //! Handles CORS preflight requests and adds CORS headers to responses
4
+
5
+ use crate::CorsConfig;
6
+ use axum::body::Body;
7
+ use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
8
+ use axum::response::IntoResponse;
9
+
10
/// Decide whether a request `Origin` is permitted by the configured allow-list.
///
/// An entry of `"*"` admits every origin; any other entry must equal the
/// origin exactly (case-sensitive comparison). An empty origin is always
/// rejected, even when a wildcard is configured.
///
/// # Arguments
/// * `origin` - Value of the request's `Origin` header (e.g. `https://example.com`)
/// * `allowed_origins` - Origins configured as allowed for CORS
///
/// # Returns
/// `true` when some configured entry matches, otherwise `false`.
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "*" || candidate == origin)
}
30
+
31
/// Decide whether an HTTP method is permitted by the CORS configuration.
///
/// An entry of `"*"` admits every method; other entries are compared with
/// `method` ignoring ASCII case, so `"get"` matches a configured `"GET"`.
///
/// # Arguments
/// * `method` - Requested method name (e.g. `"GET"`, `"POST"`)
/// * `allowed_methods` - Methods configured as allowed for CORS
///
/// # Returns
/// `true` as soon as one entry matches; `false` if none does (including when
/// the configured list is empty).
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    for entry in allowed_methods {
        if entry == "*" || entry.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
47
+
48
/// Decide whether every requested header name is permitted.
///
/// If the configuration contains `"*"`, all names are accepted. Otherwise each
/// requested name must equal some configured name, ignoring ASCII case. An
/// empty `requested` slice is vacuously accepted.
///
/// # Arguments
/// * `requested` - Header names the client asked to send
/// * `allowed` - Header names configured as allowed for CORS
///
/// # Returns
/// `true` if every requested name is allowed, `false` if at least one is not.
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    let has_wildcard = allowed.iter().any(|entry| entry == "*");

    has_wildcard
        || requested
            .iter()
            .all(|name| allowed.iter().any(|entry| entry.eq_ignore_ascii_case(name)))
}
70
+
71
+ /// Handle CORS preflight (OPTIONS) request
72
+ ///
73
+ /// Validates the request against the CORS configuration and returns appropriate
74
+ /// response or error. This function processes OPTIONS requests as defined in the
75
+ /// CORS specification (RFC 7231).
76
+ ///
77
+ /// # Validation
78
+ ///
79
+ /// Checks the following conditions:
80
+ /// 1. **Origin Header:** Must be present and match configured allowed origins
81
+ /// 2. **Access-Control-Request-Method:** Must match configured allowed methods
82
+ /// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
83
+ ///
84
+ /// # Success Response
85
+ ///
86
+ /// Returns HTTP 204 (No Content) with the following response headers:
87
+ /// - `Access-Control-Allow-Origin` - The origin that is allowed
88
+ /// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
89
+ /// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
90
+ /// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
91
+ /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
92
+ ///
93
+ /// # Error Response
94
+ ///
95
+ /// Returns HTTP 403 (Forbidden) if validation fails for:
96
+ /// - Origin not in allowed list
97
+ /// - Requested method not allowed
98
+ /// - Requested headers not allowed
99
+ ///
100
+ /// # Arguments
101
+ /// * `headers` - Request headers containing CORS preflight information
102
+ /// * `cors_config` - CORS configuration to validate against
103
+ ///
104
+ /// # Returns
105
+ /// * `Ok(Response)` - 204 No Content with CORS headers
106
+ /// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
107
+ pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
108
+ let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
109
+
110
+ if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
111
+ return Err(Box::new(
112
+ (
113
+ StatusCode::FORBIDDEN,
114
+ axum::Json(serde_json::json!({
115
+ "detail": format!("CORS request from origin '{}' not allowed", origin)
116
+ })),
117
+ )
118
+ .into_response(),
119
+ ));
120
+ }
121
+
122
+ let requested_method = headers
123
+ .get("access-control-request-method")
124
+ .and_then(|v| v.to_str().ok())
125
+ .unwrap_or("");
126
+
127
+ if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
128
+ return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
129
+ }
130
+
131
+ let requested_headers_str = headers
132
+ .get("access-control-request-headers")
133
+ .and_then(|v| v.to_str().ok());
134
+
135
+ if let Some(req_headers) = requested_headers_str {
136
+ let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
137
+
138
+ if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
139
+ return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
140
+ }
141
+ }
142
+
143
+ let mut response = Response::builder().status(StatusCode::NO_CONTENT);
144
+
145
+ let headers_mut = match response.headers_mut() {
146
+ Some(headers) => headers,
147
+ None => {
148
+ return Err(Box::new(
149
+ (
150
+ StatusCode::INTERNAL_SERVER_ERROR,
151
+ axum::Json(serde_json::json!({
152
+ "detail": "Failed to construct response headers"
153
+ })),
154
+ )
155
+ .into_response(),
156
+ ));
157
+ }
158
+ };
159
+
160
+ headers_mut.insert(
161
+ "access-control-allow-origin",
162
+ HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
163
+ );
164
+
165
+ let methods = cors_config.allowed_methods.join(", ");
166
+ headers_mut.insert(
167
+ "access-control-allow-methods",
168
+ HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
169
+ );
170
+
171
+ let allowed_headers = cors_config.allowed_headers.join(", ");
172
+ headers_mut.insert(
173
+ "access-control-allow-headers",
174
+ HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
175
+ );
176
+
177
+ if let Some(max_age) = cors_config.max_age
178
+ && let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
179
+ {
180
+ headers_mut.insert("access-control-max-age", header_val);
181
+ }
182
+
183
+ if let Some(true) = cors_config.allow_credentials {
184
+ headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
185
+ }
186
+
187
+ match response.body(Body::empty()) {
188
+ Ok(resp) => Ok(resp),
189
+ Err(_) => Err(Box::new(
190
+ (
191
+ StatusCode::INTERNAL_SERVER_ERROR,
192
+ axum::Json(serde_json::json!({
193
+ "detail": "Failed to construct response body"
194
+ })),
195
+ )
196
+ .into_response(),
197
+ )),
198
+ }
199
+ }
200
+
201
+ /// Add CORS headers to a successful response
202
+ ///
203
+ /// Adds appropriate CORS headers to the response based on the configuration.
204
+ /// This function should be called for successful (non-error) responses to
205
+ /// cross-origin requests.
206
+ ///
207
+ /// # Headers Added
208
+ ///
209
+ /// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
210
+ /// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
211
+ /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
212
+ ///
213
+ /// # Arguments
214
+ /// * `response` - Mutable reference to the response to modify
215
+ /// * `origin` - The origin from the request (e.g., `<https://example.com>`)
216
+ /// * `cors_config` - CORS configuration to apply
217
+ ///
218
+ /// # Example
219
+ ///
220
+ /// ```ignore
221
+ /// let mut response = Response::new(Body::empty());
222
+ /// add_cors_headers(&mut response, "https://example.com", &cors_config);
223
+ /// ```
224
+ pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
225
+ let headers = response.headers_mut();
226
+
227
+ if let Ok(origin_value) = HeaderValue::from_str(origin) {
228
+ headers.insert("access-control-allow-origin", origin_value);
229
+ }
230
+
231
+ if let Some(ref expose_headers) = cors_config.expose_headers {
232
+ let expose = expose_headers.join(", ");
233
+ if let Ok(expose_value) = HeaderValue::from_str(&expose) {
234
+ headers.insert("access-control-expose-headers", expose_value);
235
+ }
236
+ }
237
+
238
+ if let Some(true) = cors_config.allow_credentials {
239
+ headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
240
+ }
241
+ }
242
+
243
+ /// Validate a non-preflight CORS request
244
+ ///
245
+ /// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
246
+ /// Returns an error response if validation fails.
247
+ ///
248
+ /// # Validation
249
+ ///
250
+ /// - If no Origin header is present, the request is allowed (not a CORS request)
251
+ /// - If Origin header is present, it must match the allowed origins
252
+ ///
253
+ /// # Arguments
254
+ /// * `headers` - Request headers containing origin information
255
+ /// * `cors_config` - CORS configuration to validate against
256
+ ///
257
+ /// # Returns
258
+ /// * `Ok(())` - Request is allowed
259
+ /// * `Err(Response)` - 403 Forbidden with error details
260
+ ///
261
+ /// # Note
262
+ ///
263
+ /// This function is for actual requests, not OPTIONS preflight requests.
264
+ /// Use `handle_preflight` for OPTIONS requests.
265
+ pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
266
+ let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
267
+
268
+ if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
269
+ return Err(Box::new(
270
+ (
271
+ StatusCode::FORBIDDEN,
272
+ axum::Json(serde_json::json!({
273
+ "detail": format!("CORS request from origin '{}' not allowed", origin)
274
+ })),
275
+ )
276
+ .into_response(),
277
+ ));
278
+ }
279
+ Ok(())
280
+ }
281
+
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration used by the tests below. It allows one exact
    /// origin, GET and POST, and two request headers, and exposes one header.
    /// It also sets a one-hour max-age and enables credentials.
    fn make_cors_config() -> CorsConfig {
        CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
            expose_headers: Some(vec!["x-custom-header".to_string()]),
            max_age: Some(3600),
            allow_credentials: Some(true),
        }
    }

    /// Convert string literals into the owned list shape the CORS helpers take.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Build a request `HeaderMap` from static `(name, value)` pairs.
    fn request_headers(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Assert that a CORS check was rejected with 403 Forbidden.
    fn assert_forbidden<T>(outcome: Result<T, Box<Response<Body>>>) {
        match outcome {
            Ok(_) => panic!("expected the CORS check to be rejected"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[test]
    fn test_is_origin_allowed_exact_match() {
        let allowed = owned(&["https://example.com"]);
        assert!(is_origin_allowed("https://example.com", &allowed));
        assert!(!is_origin_allowed("https://evil.com", &allowed));
    }

    #[test]
    fn test_is_origin_allowed_wildcard() {
        let allowed = owned(&["*"]);
        for origin in ["https://example.com", "https://any-domain.com"] {
            assert!(is_origin_allowed(origin, &allowed), "{origin} should be allowed");
        }
    }

    #[test]
    fn test_is_origin_allowed_empty_origin() {
        let allowed = owned(&["*"]);
        assert!(!is_origin_allowed("", &allowed));
    }

    #[test]
    fn test_is_method_allowed_case_insensitive() {
        let allowed = owned(&["GET", "POST"]);
        for method in ["GET", "get", "POST", "post"] {
            assert!(is_method_allowed(method, &allowed), "{method} should be allowed");
        }
        assert!(!is_method_allowed("DELETE", &allowed));
    }

    #[test]
    fn test_is_method_allowed_wildcard() {
        let allowed = owned(&["*"]);
        for method in ["GET", "DELETE", "PATCH"] {
            assert!(is_method_allowed(method, &allowed), "{method} should be allowed");
        }
    }

    #[test]
    fn test_are_headers_allowed_case_insensitive() {
        let allowed = owned(&["Content-Type", "Authorization"]);
        assert!(are_headers_allowed(&["content-type"], &allowed));
        assert!(are_headers_allowed(&["AUTHORIZATION"], &allowed));
        assert!(are_headers_allowed(&["content-type", "authorization"], &allowed));
        assert!(!are_headers_allowed(&["x-custom"], &allowed));
    }

    #[test]
    fn test_are_headers_allowed_wildcard() {
        let allowed = owned(&["*"]);
        assert!(are_headers_allowed(&["any-header"], &allowed));
        assert!(are_headers_allowed(&["multiple", "headers"], &allowed));
    }

    #[test]
    fn test_handle_preflight_success() {
        let config = make_cors_config();
        let headers = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "POST"),
            ("access-control-request-headers", "content-type"),
        ]);

        let response = handle_preflight(&headers, &config).expect("preflight should succeed");
        assert_eq!(response.status(), StatusCode::NO_CONTENT);

        let out = response.headers();
        assert_eq!(out.get("access-control-allow-origin").unwrap(), "https://example.com");
        let methods = out
            .get("access-control-allow-methods")
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();
        assert!(methods.contains("POST"));
        assert_eq!(out.get("access-control-max-age").unwrap(), "3600");
        assert_eq!(out.get("access-control-allow-credentials").unwrap(), "true");
    }

    #[test]
    fn test_handle_preflight_origin_not_allowed() {
        let headers = request_headers(&[
            ("origin", "https://evil.com"),
            ("access-control-request-method", "GET"),
        ]);
        assert_forbidden(handle_preflight(&headers, &make_cors_config()));
    }

    #[test]
    fn test_handle_preflight_method_not_allowed() {
        let headers = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "DELETE"),
        ]);
        assert_forbidden(handle_preflight(&headers, &make_cors_config()));
    }

    #[test]
    fn test_handle_preflight_header_not_allowed() {
        let headers = request_headers(&[
            ("origin", "https://example.com"),
            ("access-control-request-method", "POST"),
            ("access-control-request-headers", "x-forbidden-header"),
        ]);
        assert_forbidden(handle_preflight(&headers, &make_cors_config()));
    }

    #[test]
    fn test_handle_preflight_empty_origin() {
        assert_forbidden(handle_preflight(&HeaderMap::new(), &make_cors_config()));
    }

    #[test]
    fn test_add_cors_headers() {
        let mut response = Response::new(Body::empty());
        add_cors_headers(&mut response, "https://example.com", &make_cors_config());

        let out = response.headers();
        assert_eq!(out.get("access-control-allow-origin").unwrap(), "https://example.com");
        assert_eq!(out.get("access-control-expose-headers").unwrap(), "x-custom-header");
        assert_eq!(out.get("access-control-allow-credentials").unwrap(), "true");
    }

    #[test]
    fn test_validate_cors_request_allowed() {
        let headers = request_headers(&[("origin", "https://example.com")]);
        assert!(validate_cors_request(&headers, &make_cors_config()).is_ok());
    }

    #[test]
    fn test_validate_cors_request_not_allowed() {
        let headers = request_headers(&[("origin", "https://evil.com")]);
        assert_forbidden(validate_cors_request(&headers, &make_cors_config()));
    }

    #[test]
    fn test_validate_cors_request_no_origin() {
        assert!(validate_cors_request(&HeaderMap::new(), &make_cors_config()).is_ok());
    }
}