spikard 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +674 -674
  4. data/ext/spikard_rb/Cargo.toml +17 -17
  5. data/ext/spikard_rb/extconf.rb +13 -10
  6. data/ext/spikard_rb/src/lib.rs +6 -6
  7. data/lib/spikard/app.rb +405 -405
  8. data/lib/spikard/background.rb +27 -27
  9. data/lib/spikard/config.rb +396 -396
  10. data/lib/spikard/converters.rb +13 -13
  11. data/lib/spikard/handler_wrapper.rb +113 -113
  12. data/lib/spikard/provide.rb +214 -214
  13. data/lib/spikard/response.rb +173 -173
  14. data/lib/spikard/schema.rb +243 -243
  15. data/lib/spikard/sse.rb +111 -111
  16. data/lib/spikard/streaming_response.rb +44 -44
  17. data/lib/spikard/testing.rb +256 -256
  18. data/lib/spikard/upload_file.rb +131 -131
  19. data/lib/spikard/version.rb +5 -5
  20. data/lib/spikard/websocket.rb +59 -59
  21. data/lib/spikard.rb +43 -43
  22. data/sig/spikard.rbs +366 -366
  23. data/vendor/crates/spikard-bindings-shared/Cargo.toml +63 -63
  24. data/vendor/crates/spikard-bindings-shared/examples/config_extraction.rs +132 -132
  25. data/vendor/crates/spikard-bindings-shared/src/config_extractor.rs +752 -752
  26. data/vendor/crates/spikard-bindings-shared/src/conversion_traits.rs +194 -194
  27. data/vendor/crates/spikard-bindings-shared/src/di_traits.rs +246 -246
  28. data/vendor/crates/spikard-bindings-shared/src/error_response.rs +401 -401
  29. data/vendor/crates/spikard-bindings-shared/src/handler_base.rs +238 -238
  30. data/vendor/crates/spikard-bindings-shared/src/lib.rs +24 -24
  31. data/vendor/crates/spikard-bindings-shared/src/lifecycle_base.rs +292 -292
  32. data/vendor/crates/spikard-bindings-shared/src/lifecycle_executor.rs +616 -616
  33. data/vendor/crates/spikard-bindings-shared/src/response_builder.rs +305 -305
  34. data/vendor/crates/spikard-bindings-shared/src/test_client_base.rs +248 -248
  35. data/vendor/crates/spikard-bindings-shared/src/validation_helpers.rs +351 -351
  36. data/vendor/crates/spikard-bindings-shared/tests/comprehensive_coverage.rs +454 -454
  37. data/vendor/crates/spikard-bindings-shared/tests/error_response_edge_cases.rs +383 -383
  38. data/vendor/crates/spikard-bindings-shared/tests/handler_base_integration.rs +280 -280
  39. data/vendor/crates/spikard-core/Cargo.toml +40 -40
  40. data/vendor/crates/spikard-core/src/bindings/mod.rs +3 -3
  41. data/vendor/crates/spikard-core/src/bindings/response.rs +133 -133
  42. data/vendor/crates/spikard-core/src/debug.rs +127 -127
  43. data/vendor/crates/spikard-core/src/di/container.rs +702 -702
  44. data/vendor/crates/spikard-core/src/di/dependency.rs +273 -273
  45. data/vendor/crates/spikard-core/src/di/error.rs +118 -118
  46. data/vendor/crates/spikard-core/src/di/factory.rs +534 -534
  47. data/vendor/crates/spikard-core/src/di/graph.rs +506 -506
  48. data/vendor/crates/spikard-core/src/di/mod.rs +192 -192
  49. data/vendor/crates/spikard-core/src/di/resolved.rs +405 -405
  50. data/vendor/crates/spikard-core/src/di/value.rs +281 -281
  51. data/vendor/crates/spikard-core/src/errors.rs +69 -69
  52. data/vendor/crates/spikard-core/src/http.rs +415 -415
  53. data/vendor/crates/spikard-core/src/lib.rs +29 -29
  54. data/vendor/crates/spikard-core/src/lifecycle.rs +1186 -1186
  55. data/vendor/crates/spikard-core/src/metadata.rs +389 -389
  56. data/vendor/crates/spikard-core/src/parameters.rs +2525 -2525
  57. data/vendor/crates/spikard-core/src/problem.rs +344 -344
  58. data/vendor/crates/spikard-core/src/request_data.rs +1154 -1154
  59. data/vendor/crates/spikard-core/src/router.rs +510 -510
  60. data/vendor/crates/spikard-core/src/schema_registry.rs +183 -183
  61. data/vendor/crates/spikard-core/src/type_hints.rs +304 -304
  62. data/vendor/crates/spikard-core/src/validation/error_mapper.rs +696 -688
  63. data/vendor/crates/spikard-core/src/validation/mod.rs +457 -457
  64. data/vendor/crates/spikard-http/Cargo.toml +62 -64
  65. data/vendor/crates/spikard-http/examples/sse-notifications.rs +148 -148
  66. data/vendor/crates/spikard-http/examples/websocket-chat.rs +92 -92
  67. data/vendor/crates/spikard-http/src/auth.rs +296 -296
  68. data/vendor/crates/spikard-http/src/background.rs +1860 -1860
  69. data/vendor/crates/spikard-http/src/bindings/mod.rs +3 -3
  70. data/vendor/crates/spikard-http/src/bindings/response.rs +1 -1
  71. data/vendor/crates/spikard-http/src/body_metadata.rs +8 -8
  72. data/vendor/crates/spikard-http/src/cors.rs +1005 -1005
  73. data/vendor/crates/spikard-http/src/debug.rs +128 -128
  74. data/vendor/crates/spikard-http/src/di_handler.rs +1668 -1668
  75. data/vendor/crates/spikard-http/src/handler_response.rs +901 -901
  76. data/vendor/crates/spikard-http/src/handler_trait.rs +838 -830
  77. data/vendor/crates/spikard-http/src/handler_trait_tests.rs +290 -290
  78. data/vendor/crates/spikard-http/src/lib.rs +534 -534
  79. data/vendor/crates/spikard-http/src/lifecycle/adapter.rs +230 -230
  80. data/vendor/crates/spikard-http/src/lifecycle.rs +1193 -1193
  81. data/vendor/crates/spikard-http/src/middleware/mod.rs +560 -540
  82. data/vendor/crates/spikard-http/src/middleware/multipart.rs +912 -912
  83. data/vendor/crates/spikard-http/src/middleware/urlencoded.rs +513 -513
  84. data/vendor/crates/spikard-http/src/middleware/validation.rs +768 -735
  85. data/vendor/crates/spikard-http/src/openapi/mod.rs +309 -309
  86. data/vendor/crates/spikard-http/src/openapi/parameter_extraction.rs +535 -535
  87. data/vendor/crates/spikard-http/src/openapi/schema_conversion.rs +1363 -1363
  88. data/vendor/crates/spikard-http/src/openapi/spec_generation.rs +665 -665
  89. data/vendor/crates/spikard-http/src/query_parser.rs +793 -793
  90. data/vendor/crates/spikard-http/src/response.rs +720 -720
  91. data/vendor/crates/spikard-http/src/server/handler.rs +1650 -1650
  92. data/vendor/crates/spikard-http/src/server/lifecycle_execution.rs +234 -234
  93. data/vendor/crates/spikard-http/src/server/mod.rs +1593 -1502
  94. data/vendor/crates/spikard-http/src/server/request_extraction.rs +789 -770
  95. data/vendor/crates/spikard-http/src/server/routing_factory.rs +629 -599
  96. data/vendor/crates/spikard-http/src/sse.rs +1409 -1409
  97. data/vendor/crates/spikard-http/src/testing/form.rs +52 -52
  98. data/vendor/crates/spikard-http/src/testing/multipart.rs +64 -60
  99. data/vendor/crates/spikard-http/src/testing/test_client.rs +311 -283
  100. data/vendor/crates/spikard-http/src/testing.rs +406 -377
  101. data/vendor/crates/spikard-http/src/websocket.rs +1404 -1375
  102. data/vendor/crates/spikard-http/tests/background_behavior.rs +832 -832
  103. data/vendor/crates/spikard-http/tests/common/handlers.rs +309 -309
  104. data/vendor/crates/spikard-http/tests/common/mod.rs +26 -26
  105. data/vendor/crates/spikard-http/tests/di_integration.rs +192 -192
  106. data/vendor/crates/spikard-http/tests/doc_snippets.rs +5 -5
  107. data/vendor/crates/spikard-http/tests/lifecycle_execution.rs +1093 -1093
  108. data/vendor/crates/spikard-http/tests/multipart_behavior.rs +656 -656
  109. data/vendor/crates/spikard-http/tests/server_config_builder.rs +314 -314
  110. data/vendor/crates/spikard-http/tests/sse_behavior.rs +620 -620
  111. data/vendor/crates/spikard-http/tests/websocket_behavior.rs +663 -663
  112. data/vendor/crates/spikard-rb/Cargo.toml +48 -48
  113. data/vendor/crates/spikard-rb/build.rs +199 -199
  114. data/vendor/crates/spikard-rb/src/background.rs +63 -63
  115. data/vendor/crates/spikard-rb/src/config/mod.rs +5 -5
  116. data/vendor/crates/spikard-rb/src/config/server_config.rs +285 -285
  117. data/vendor/crates/spikard-rb/src/conversion.rs +554 -554
  118. data/vendor/crates/spikard-rb/src/di/builder.rs +100 -100
  119. data/vendor/crates/spikard-rb/src/di/mod.rs +375 -375
  120. data/vendor/crates/spikard-rb/src/handler.rs +618 -618
  121. data/vendor/crates/spikard-rb/src/integration/mod.rs +3 -3
  122. data/vendor/crates/spikard-rb/src/lib.rs +1806 -1810
  123. data/vendor/crates/spikard-rb/src/lifecycle.rs +275 -275
  124. data/vendor/crates/spikard-rb/src/metadata/mod.rs +5 -5
  125. data/vendor/crates/spikard-rb/src/metadata/route_extraction.rs +442 -447
  126. data/vendor/crates/spikard-rb/src/runtime/mod.rs +5 -5
  127. data/vendor/crates/spikard-rb/src/runtime/server_runner.rs +324 -324
  128. data/vendor/crates/spikard-rb/src/server.rs +305 -308
  129. data/vendor/crates/spikard-rb/src/sse.rs +231 -231
  130. data/vendor/crates/spikard-rb/src/testing/client.rs +538 -551
  131. data/vendor/crates/spikard-rb/src/testing/mod.rs +7 -7
  132. data/vendor/crates/spikard-rb/src/testing/sse.rs +143 -143
  133. data/vendor/crates/spikard-rb/src/testing/websocket.rs +608 -635
  134. data/vendor/crates/spikard-rb/src/websocket.rs +377 -374
  135. metadata +15 -1
@@ -1,1005 +1,1005 @@
1
- //! CORS (Cross-Origin Resource Sharing) handling
2
- //!
3
- //! Handles CORS preflight requests and adds CORS headers to responses
4
-
5
- use crate::CorsConfig;
6
- use axum::body::Body;
7
- use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
8
- use axum::response::IntoResponse;
9
-
10
/// Decide whether a request origin is permitted by the CORS allow-list.
///
/// A configured entry of `"*"` admits every origin. Any other entry must equal
/// the origin exactly, byte for byte, so scheme, host, port and letter case all
/// matter. An empty origin is never admitted, even when `"*"` is configured.
///
/// # Arguments
/// * `origin` - Value of the request's `Origin` header (e.g., "https://example.com")
/// * `allowed_origins` - Origins configured for CORS
///
/// # Returns
/// `true` when the origin is admitted, `false` otherwise
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "*" || candidate == origin)
}
30
-
31
/// Decide whether an HTTP method is permitted by the CORS configuration.
///
/// A configured entry of `"*"` admits every method. Other entries are compared
/// ASCII case-insensitively, so a request for "get" matches a configured "GET".
///
/// # Arguments
/// * `method` - The requested HTTP method (e.g., "GET", "POST")
/// * `allowed_methods` - HTTP methods configured for CORS
///
/// # Returns
/// `true` when the method is admitted, `false` otherwise
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    for candidate in allowed_methods {
        if candidate == "*" || candidate.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
47
-
48
/// Check if all requested headers are allowed by the CORS configuration.
///
/// Header names are compared ASCII case-insensitively. A configured wildcard
/// (`"*"`) allows any header.
///
/// Each requested name is trimmed, and empty entries are ignored. An empty
/// `Access-Control-Request-Headers` value or a stray comma (e.g.
/// `"content-type,"`) splits into empty tokens. Those tokens do not name a
/// header and must not cause the whole preflight to be rejected.
///
/// # Arguments
/// * `requested` - Header names requested by the client
/// * `allowed` - Header names configured for CORS
///
/// # Returns
/// `true` if every (non-empty) requested header is allowed, `false` if any is not
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    if allowed.iter().any(|h| h == "*") {
        return true;
    }

    requested
        .iter()
        .map(|name| name.trim())
        .filter(|name| !name.is_empty())
        .all(|name| {
            allowed
                .iter()
                .any(|allowed_header| allowed_header.eq_ignore_ascii_case(name))
        })
}
70
-
71
- /// Handle CORS preflight (OPTIONS) request
72
- ///
73
- /// Validates the request against the CORS configuration and returns appropriate
74
- /// response or error. This function processes OPTIONS requests as defined in the
75
- /// CORS specification (RFC 7231).
76
- ///
77
- /// # Validation
78
- ///
79
- /// Checks the following conditions:
80
- /// 1. **Origin Header:** Must be present and match configured allowed origins
81
- /// 2. **Access-Control-Request-Method:** Must match configured allowed methods
82
- /// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
83
- ///
84
- /// # Success Response
85
- ///
86
- /// Returns HTTP 204 (No Content) with the following response headers:
87
- /// - `Access-Control-Allow-Origin` - The origin that is allowed
88
- /// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
89
- /// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
90
- /// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
91
- /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
92
- ///
93
- /// # Error Response
94
- ///
95
- /// Returns HTTP 403 (Forbidden) if validation fails for:
96
- /// - Origin not in allowed list
97
- /// - Requested method not allowed
98
- /// - Requested headers not allowed
99
- ///
100
- /// # Arguments
101
- /// * `headers` - Request headers containing CORS preflight information
102
- /// * `cors_config` - CORS configuration to validate against
103
- ///
104
- /// # Returns
105
- /// * `Ok(Response)` - 204 No Content with CORS headers
106
- /// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
107
- pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
108
- let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
109
-
110
- if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
111
- return Err(Box::new(
112
- (
113
- StatusCode::FORBIDDEN,
114
- axum::Json(serde_json::json!({
115
- "detail": format!("CORS request from origin '{}' not allowed", origin)
116
- })),
117
- )
118
- .into_response(),
119
- ));
120
- }
121
-
122
- let requested_method = headers
123
- .get("access-control-request-method")
124
- .and_then(|v| v.to_str().ok())
125
- .unwrap_or("");
126
-
127
- if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
128
- return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
129
- }
130
-
131
- let requested_headers_str = headers
132
- .get("access-control-request-headers")
133
- .and_then(|v| v.to_str().ok());
134
-
135
- if let Some(req_headers) = requested_headers_str {
136
- let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
137
-
138
- if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
139
- return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
140
- }
141
- }
142
-
143
- let mut response = Response::builder().status(StatusCode::NO_CONTENT);
144
-
145
- let headers_mut = match response.headers_mut() {
146
- Some(headers) => headers,
147
- None => {
148
- return Err(Box::new(
149
- (
150
- StatusCode::INTERNAL_SERVER_ERROR,
151
- axum::Json(serde_json::json!({
152
- "detail": "Failed to construct response headers"
153
- })),
154
- )
155
- .into_response(),
156
- ));
157
- }
158
- };
159
-
160
- headers_mut.insert(
161
- "access-control-allow-origin",
162
- HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
163
- );
164
-
165
- let methods = cors_config.allowed_methods.join(", ");
166
- headers_mut.insert(
167
- "access-control-allow-methods",
168
- HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
169
- );
170
-
171
- let allowed_headers = cors_config.allowed_headers.join(", ");
172
- headers_mut.insert(
173
- "access-control-allow-headers",
174
- HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
175
- );
176
-
177
- if let Some(max_age) = cors_config.max_age
178
- && let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
179
- {
180
- headers_mut.insert("access-control-max-age", header_val);
181
- }
182
-
183
- if let Some(true) = cors_config.allow_credentials {
184
- headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
185
- }
186
-
187
- match response.body(Body::empty()) {
188
- Ok(resp) => Ok(resp),
189
- Err(_) => Err(Box::new(
190
- (
191
- StatusCode::INTERNAL_SERVER_ERROR,
192
- axum::Json(serde_json::json!({
193
- "detail": "Failed to construct response body"
194
- })),
195
- )
196
- .into_response(),
197
- )),
198
- }
199
- }
200
-
201
- /// Add CORS headers to a successful response
202
- ///
203
- /// Adds appropriate CORS headers to the response based on the configuration.
204
- /// This function should be called for successful (non-error) responses to
205
- /// cross-origin requests.
206
- ///
207
- /// # Headers Added
208
- ///
209
- /// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
210
- /// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
211
- /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
212
- ///
213
- /// # Arguments
214
- /// * `response` - Mutable reference to the response to modify
215
- /// * `origin` - The origin from the request (e.g., `<https://example.com>`)
216
- /// * `cors_config` - CORS configuration to apply
217
- ///
218
- /// # Example
219
- ///
220
- /// ```ignore
221
- /// let mut response = Response::new(Body::empty());
222
- /// add_cors_headers(&mut response, "https://example.com", &cors_config);
223
- /// ```
224
- pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
225
- let headers = response.headers_mut();
226
-
227
- if let Ok(origin_value) = HeaderValue::from_str(origin) {
228
- headers.insert("access-control-allow-origin", origin_value);
229
- }
230
-
231
- if let Some(ref expose_headers) = cors_config.expose_headers {
232
- let expose = expose_headers.join(", ");
233
- if let Ok(expose_value) = HeaderValue::from_str(&expose) {
234
- headers.insert("access-control-expose-headers", expose_value);
235
- }
236
- }
237
-
238
- if let Some(true) = cors_config.allow_credentials {
239
- headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
240
- }
241
- }
242
-
243
- /// Validate a non-preflight CORS request
244
- ///
245
- /// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
246
- /// Returns an error response if validation fails.
247
- ///
248
- /// # Validation
249
- ///
250
- /// - If no Origin header is present, the request is allowed (not a CORS request)
251
- /// - If Origin header is present, it must match the allowed origins
252
- ///
253
- /// # Arguments
254
- /// * `headers` - Request headers containing origin information
255
- /// * `cors_config` - CORS configuration to validate against
256
- ///
257
- /// # Returns
258
- /// * `Ok(())` - Request is allowed
259
- /// * `Err(Response)` - 403 Forbidden with error details
260
- ///
261
- /// # Note
262
- ///
263
- /// This function is for actual requests, not OPTIONS preflight requests.
264
- /// Use `handle_preflight` for OPTIONS requests.
265
- pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
266
- let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
267
-
268
- if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
269
- return Err(Box::new(
270
- (
271
- StatusCode::FORBIDDEN,
272
- axum::Json(serde_json::json!({
273
- "detail": format!("CORS request from origin '{}' not allowed", origin)
274
- })),
275
- )
276
- .into_response(),
277
- ));
278
- }
279
- Ok(())
280
- }
281
-
282
- #[cfg(test)]
283
- mod tests {
284
- use super::*;
285
-
286
- fn make_cors_config() -> CorsConfig {
287
- CorsConfig {
288
- allowed_origins: vec!["https://example.com".to_string()],
289
- allowed_methods: vec!["GET".to_string(), "POST".to_string()],
290
- allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
291
- expose_headers: Some(vec!["x-custom-header".to_string()]),
292
- max_age: Some(3600),
293
- allow_credentials: Some(true),
294
- }
295
- }
296
-
297
- #[test]
298
- fn test_is_origin_allowed_exact_match() {
299
- let allowed = vec!["https://example.com".to_string()];
300
- assert!(is_origin_allowed("https://example.com", &allowed));
301
- assert!(!is_origin_allowed("https://evil.com", &allowed));
302
- }
303
-
304
- #[test]
305
- fn test_is_origin_allowed_wildcard() {
306
- let allowed = vec!["*".to_string()];
307
- assert!(is_origin_allowed("https://example.com", &allowed));
308
- assert!(is_origin_allowed("https://any-domain.com", &allowed));
309
- }
310
-
311
- #[test]
312
- fn test_is_origin_allowed_empty_origin() {
313
- let allowed = vec!["*".to_string()];
314
- assert!(!is_origin_allowed("", &allowed));
315
- }
316
-
317
- #[test]
318
- fn test_is_method_allowed_case_insensitive() {
319
- let allowed = vec!["GET".to_string(), "POST".to_string()];
320
- assert!(is_method_allowed("GET", &allowed));
321
- assert!(is_method_allowed("get", &allowed));
322
- assert!(is_method_allowed("POST", &allowed));
323
- assert!(is_method_allowed("post", &allowed));
324
- assert!(!is_method_allowed("DELETE", &allowed));
325
- }
326
-
327
- #[test]
328
- fn test_is_method_allowed_wildcard() {
329
- let allowed = vec!["*".to_string()];
330
- assert!(is_method_allowed("GET", &allowed));
331
- assert!(is_method_allowed("DELETE", &allowed));
332
- assert!(is_method_allowed("PATCH", &allowed));
333
- }
334
-
335
- #[test]
336
- fn test_are_headers_allowed_case_insensitive() {
337
- let allowed = vec!["Content-Type".to_string(), "Authorization".to_string()];
338
- assert!(are_headers_allowed(&["content-type"], &allowed));
339
- assert!(are_headers_allowed(&["AUTHORIZATION"], &allowed));
340
- assert!(are_headers_allowed(&["content-type", "authorization"], &allowed));
341
- assert!(!are_headers_allowed(&["x-custom"], &allowed));
342
- }
343
-
344
- #[test]
345
- fn test_are_headers_allowed_wildcard() {
346
- let allowed = vec!["*".to_string()];
347
- assert!(are_headers_allowed(&["any-header"], &allowed));
348
- assert!(are_headers_allowed(&["multiple", "headers"], &allowed));
349
- }
350
-
351
- #[test]
352
- fn test_handle_preflight_success() {
353
- let config = make_cors_config();
354
- let mut headers = HeaderMap::new();
355
- headers.insert("origin", HeaderValue::from_static("https://example.com"));
356
- headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
357
- headers.insert(
358
- "access-control-request-headers",
359
- HeaderValue::from_static("content-type"),
360
- );
361
-
362
- let result = handle_preflight(&headers, &config);
363
- assert!(result.is_ok());
364
-
365
- let response = result.unwrap();
366
- assert_eq!(response.status(), StatusCode::NO_CONTENT);
367
-
368
- let resp_headers = response.headers();
369
- assert_eq!(
370
- resp_headers.get("access-control-allow-origin").unwrap(),
371
- "https://example.com"
372
- );
373
- assert!(
374
- resp_headers
375
- .get("access-control-allow-methods")
376
- .unwrap()
377
- .to_str()
378
- .unwrap()
379
- .contains("POST")
380
- );
381
- assert_eq!(resp_headers.get("access-control-max-age").unwrap(), "3600");
382
- assert_eq!(resp_headers.get("access-control-allow-credentials").unwrap(), "true");
383
- }
384
-
385
- #[test]
386
- fn test_handle_preflight_origin_not_allowed() {
387
- let config = make_cors_config();
388
- let mut headers = HeaderMap::new();
389
- headers.insert("origin", HeaderValue::from_static("https://evil.com"));
390
- headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
391
-
392
- let result = handle_preflight(&headers, &config);
393
- assert!(result.is_err());
394
-
395
- let response = *result.unwrap_err();
396
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
397
- }
398
-
399
- #[test]
400
- fn test_handle_preflight_method_not_allowed() {
401
- let config = make_cors_config();
402
- let mut headers = HeaderMap::new();
403
- headers.insert("origin", HeaderValue::from_static("https://example.com"));
404
- headers.insert("access-control-request-method", HeaderValue::from_static("DELETE"));
405
-
406
- let result = handle_preflight(&headers, &config);
407
- assert!(result.is_err());
408
-
409
- let response = *result.unwrap_err();
410
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
411
- }
412
-
413
- #[test]
414
- fn test_handle_preflight_header_not_allowed() {
415
- let config = make_cors_config();
416
- let mut headers = HeaderMap::new();
417
- headers.insert("origin", HeaderValue::from_static("https://example.com"));
418
- headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
419
- headers.insert(
420
- "access-control-request-headers",
421
- HeaderValue::from_static("x-forbidden-header"),
422
- );
423
-
424
- let result = handle_preflight(&headers, &config);
425
- assert!(result.is_err());
426
-
427
- let response = *result.unwrap_err();
428
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
429
- }
430
-
431
- #[test]
432
- fn test_handle_preflight_empty_origin() {
433
- let config = make_cors_config();
434
- let headers = HeaderMap::new();
435
-
436
- let result = handle_preflight(&headers, &config);
437
- assert!(result.is_err());
438
-
439
- let response = *result.unwrap_err();
440
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
441
- }
442
-
443
- #[test]
444
- fn test_add_cors_headers() {
445
- let config = make_cors_config();
446
- let mut response = Response::new(Body::empty());
447
-
448
- add_cors_headers(&mut response, "https://example.com", &config);
449
-
450
- let headers = response.headers();
451
- assert_eq!(
452
- headers.get("access-control-allow-origin").unwrap(),
453
- "https://example.com"
454
- );
455
- assert_eq!(headers.get("access-control-expose-headers").unwrap(), "x-custom-header");
456
- assert_eq!(headers.get("access-control-allow-credentials").unwrap(), "true");
457
- }
458
-
459
- #[test]
460
- fn test_validate_cors_request_allowed() {
461
- let config = make_cors_config();
462
- let mut headers = HeaderMap::new();
463
- headers.insert("origin", HeaderValue::from_static("https://example.com"));
464
-
465
- let result = validate_cors_request(&headers, &config);
466
- assert!(result.is_ok());
467
- }
468
-
469
- #[test]
470
- fn test_validate_cors_request_not_allowed() {
471
- let config = make_cors_config();
472
- let mut headers = HeaderMap::new();
473
- headers.insert("origin", HeaderValue::from_static("https://evil.com"));
474
-
475
- let result = validate_cors_request(&headers, &config);
476
- assert!(result.is_err());
477
-
478
- let response = *result.unwrap_err();
479
- assert_eq!(response.status(), StatusCode::FORBIDDEN);
480
- }
481
-
482
- #[test]
483
- fn test_validate_cors_request_no_origin() {
484
- let config = make_cors_config();
485
- let headers = HeaderMap::new();
486
-
487
- let result = validate_cors_request(&headers, &config);
488
- assert!(result.is_ok());
489
- }
490
-
491
- // SECURITY TESTS: CORS Attack Vectors
492
-
493
- /// SECURITY TEST: Verify credentials=true with wildcard is caught
494
- /// This is a critical vulnerability - RFC 6454 forbids this
495
- #[test]
496
- fn test_credentials_with_wildcard_config_is_security_issue() {
497
- let config = CorsConfig {
498
- allowed_origins: vec!["*".to_string()],
499
- allowed_methods: vec!["GET".to_string()],
500
- allowed_headers: vec![],
501
- expose_headers: None,
502
- max_age: None,
503
- allow_credentials: Some(true), // SECURITY BUG: This should not be allowed with wildcard
504
- };
505
-
506
- let mut headers = HeaderMap::new();
507
- headers.insert("origin", HeaderValue::from_static("https://evil.com"));
508
- headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
509
-
510
- let result = handle_preflight(&headers, &config);
511
-
512
- // BUG: This should return 500 or reject the config, but instead succeeds
513
- if let Ok(response) = result {
514
- let resp_headers = response.headers();
515
- let has_credentials = resp_headers
516
- .get("access-control-allow-credentials")
517
- .map(|v| v == "true")
518
- .unwrap_or(false);
519
- let origin_header = resp_headers.get("access-control-allow-origin");
520
-
521
- if has_credentials && origin_header.is_some() {
522
- let origin_val = origin_header.unwrap().to_str().unwrap_or("");
523
- if origin_val == "*" {
524
- panic!("SECURITY VULNERABILITY: credentials=true with origin=* allowed");
525
- }
526
- }
527
- }
528
- }
529
-
530
- /// SECURITY TEST: Exact origin matching required
531
- /// Subdomain like api.evil.example.com must NOT match example.com
532
- #[test]
533
- fn test_subdomain_bypass_blocked() {
534
- let config = CorsConfig {
535
- allowed_origins: vec!["https://example.com".to_string()],
536
- allowed_methods: vec!["GET".to_string()],
537
- allowed_headers: vec![],
538
- expose_headers: None,
539
- max_age: None,
540
- allow_credentials: None,
541
- };
542
-
543
- assert!(!is_origin_allowed("https://api.example.com", &config.allowed_origins));
544
- assert!(!is_origin_allowed("https://evil.example.com", &config.allowed_origins));
545
- assert!(!is_origin_allowed(
546
- "https://sub.sub.example.com",
547
- &config.allowed_origins
548
- ));
549
-
550
- assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
551
- }
552
-
553
- /// SECURITY TEST: Port exact matching required
554
- /// localhost:3001 must NOT match localhost:3000
555
- #[test]
556
- fn test_port_bypass_blocked() {
557
- let config = CorsConfig {
558
- allowed_origins: vec!["http://localhost:3000".to_string()],
559
- allowed_methods: vec!["GET".to_string()],
560
- allowed_headers: vec![],
561
- expose_headers: None,
562
- max_age: None,
563
- allow_credentials: None,
564
- };
565
-
566
- assert!(!is_origin_allowed("http://localhost:3001", &config.allowed_origins));
567
- assert!(!is_origin_allowed("http://localhost:8080", &config.allowed_origins));
568
- assert!(!is_origin_allowed("http://localhost:443", &config.allowed_origins));
569
-
570
- assert!(is_origin_allowed("http://localhost:3000", &config.allowed_origins));
571
- }
572
-
573
- /// SECURITY TEST: Protocol exact matching required
574
- /// http://example.com must NOT match https://example.com
575
- #[test]
576
- fn test_protocol_downgrade_attack_blocked() {
577
- let config = CorsConfig {
578
- allowed_origins: vec!["https://example.com".to_string()],
579
- allowed_methods: vec!["GET".to_string()],
580
- allowed_headers: vec![],
581
- expose_headers: None,
582
- max_age: None,
583
- allow_credentials: None,
584
- };
585
-
586
- assert!(!is_origin_allowed("http://example.com", &config.allowed_origins));
587
- assert!(!is_origin_allowed("ws://example.com", &config.allowed_origins));
588
- assert!(!is_origin_allowed("wss://example.com", &config.allowed_origins));
589
-
590
- assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
591
- }
592
-
593
- /// SECURITY TEST: Case sensitivity in origin matching
594
- /// Origins should match exactly (including case)
595
- #[test]
596
- fn test_case_sensitive_origin_matching() {
597
- let config = CorsConfig {
598
- allowed_origins: vec!["https://Example.Com".to_string()],
599
- allowed_methods: vec!["GET".to_string()],
600
- allowed_headers: vec![],
601
- expose_headers: None,
602
- max_age: None,
603
- allow_credentials: None,
604
- };
605
-
606
- assert!(!is_origin_allowed("https://example.com", &config.allowed_origins));
607
- assert!(!is_origin_allowed("https://EXAMPLE.COM", &config.allowed_origins));
608
-
609
- assert!(is_origin_allowed("https://Example.Com", &config.allowed_origins));
610
- }
611
-
612
/// SECURITY TEST: origins are not normalized.
/// `https://example.com/` (trailing slash) is a different string from
/// `https://example.com` and must not match it.
#[test]
fn test_trailing_slash_origin_not_normalized() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    let with_slash = "https://example.com/";
    let without_slash = "https://example.com";
    assert!(!is_origin_allowed(with_slash, allowed));
    assert!(is_origin_allowed(without_slash, allowed));
}
629
-
630
/// SECURITY TEST: the special "null" origin (sent by file:// pages and sandboxed
/// iframes) is treated as an ordinary origin string.
/// It is therefore accepted by a wildcard entry, which documents current behavior
/// rather than an ideal policy. It is also accepted by an explicit "null" entry.
#[test]
fn test_null_origin_with_wildcard() {
    let wildcard = CorsConfig {
        allowed_origins: vec!["*".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let explicit_null = CorsConfig {
        allowed_origins: vec!["null".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    // SECURITY NOTE: "null" origin is allowed by wildcard in current implementation
    for config in [&wildcard, &explicit_null] {
        assert!(is_origin_allowed("null", &config.allowed_origins));
    }
}
658
-
659
/// SECURITY TEST: an empty origin is rejected regardless of configuration,
/// including when a wildcard entry is present.
#[test]
fn test_empty_origin_always_rejected() {
    let config_with_wildcard = CorsConfig {
        allowed_origins: vec!["*".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let config_with_explicit = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    for config in [&config_with_wildcard, &config_with_explicit] {
        assert!(!is_origin_allowed("", &config.allowed_origins));
    }
}
682
-
683
/// SECURITY TEST: a preflight from an origin outside the allow-list is answered
/// with 403 Forbidden.
#[test]
fn test_preflight_rejects_invalid_origin() {
    let config = make_cors_config();

    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://untrusted.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("POST"));

    let rejection = handle_preflight(&request, &config).expect_err("untrusted origin must be rejected");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
697
-
698
/// SECURITY TEST: with several configured origins, each one is matched exactly.
/// Suffix-extended, similar-looking and unlisted origins are rejected.
#[test]
fn test_multiple_origins_exact_matching() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted1.com".to_string(), "https://trusted2.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for trusted in ["https://trusted1.com", "https://trusted2.com"] {
        assert!(is_origin_allowed(trusted, allowed));
    }

    for impostor in [
        "https://trusted1.com.evil.com",
        "https://trusted3.com",
        "https://trusted.com",
    ] {
        assert!(!is_origin_allowed(impostor, allowed));
    }
}
720
-
721
/// SECURITY TEST: a "*" entry admits every non-empty origin.
/// The empty origin is still rejected.
#[test]
fn test_wildcard_allows_all_but_check_credentials() {
    let config = CorsConfig {
        allowed_origins: vec!["*".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    let any_origin = ["https://example.com", "https://evil.com", "http://localhost:3000"];
    assert!(any_origin.iter().all(|origin| is_origin_allowed(origin, allowed)));

    assert!(!is_origin_allowed("", allowed));
}
739
-
740
/// SECURITY TEST: a successful preflight echoes the trusted origin and lists the
/// configured methods. It omits Allow-Credentials when credentials are `Some(false)`.
#[test]
fn test_preflight_response_has_correct_allowed_origins() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string(), "POST".to_string()],
        allowed_headers: vec!["content-type".to_string()],
        expose_headers: None,
        max_age: Some(3600),
        allow_credentials: Some(false),
    };

    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://trusted.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("POST"));
    request.insert(
        "access-control-request-headers",
        HeaderValue::from_static("content-type"),
    );

    let response = handle_preflight(&request, &config).expect("trusted preflight should succeed");
    let resp_headers = response.headers();

    assert_eq!(
        resp_headers.get("access-control-allow-origin").unwrap(),
        "https://trusted.com"
    );

    let allow_methods = resp_headers
        .get("access-control-allow-methods")
        .unwrap()
        .to_str()
        .unwrap();
    for method in ["GET", "POST"] {
        assert!(allow_methods.contains(method));
    }

    assert!(resp_headers.get("access-control-allow-credentials").is_none());
}
790
-
791
/// SECURITY TEST: every preflight origin is validated.
/// Only the configured origin succeeds; look-alikes, other schemes and the empty
/// origin are rejected.
#[test]
fn test_preflight_all_origins_require_validation() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let candidates = [
        "https://trusted.com",
        "https://evil.com",
        "https://trusted.com.evil",
        "http://trusted.com",
        "",
    ];

    for origin in candidates {
        let origin_value = HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static(""));
        let mut request = HeaderMap::new();
        request.insert("origin", origin_value);
        request.insert("access-control-request-method", HeaderValue::from_static("GET"));

        let outcome = handle_preflight(&request, &config);

        match origin {
            "https://trusted.com" => assert!(outcome.is_ok(), "Valid origin {} should be allowed", origin),
            _ => assert!(outcome.is_err(), "Invalid origin {} should be rejected", origin),
        }
    }
}
828
-
829
/// SECURITY TEST: every header named in Access-Control-Request-Headers must be
/// configured. A single unknown header rejects the whole preflight.
#[test]
fn test_preflight_validates_all_requested_headers() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["POST".to_string()],
        allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let cases = [
        ("content-type", true),
        ("authorization", true),
        ("content-type, authorization", true),
        ("x-custom-header", false),
        ("content-type, x-custom", false),
    ];

    for &(requested, expected_ok) in cases.iter() {
        let mut request = HeaderMap::new();
        request.insert("origin", HeaderValue::from_static("https://trusted.com"));
        request.insert("access-control-request-method", HeaderValue::from_static("POST"));
        request.insert(
            "access-control-request-headers",
            HeaderValue::from_str(requested).unwrap(),
        );

        let outcome = handle_preflight(&request, &config);

        if expected_ok {
            assert!(
                outcome.is_ok(),
                "Preflight with valid headers '{}' should pass",
                requested
            );
        } else {
            assert!(
                outcome.is_err(),
                "Preflight with invalid headers '{}' should fail",
                requested
            );
        }
    }
}
875
-
876
/// SECURITY TEST: `add_cors_headers` echoes the given origin, joins the configured
/// expose headers, and sets Allow-Credentials when credentials are enabled.
#[test]
fn test_add_cors_headers_respects_origin() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: Some(vec!["x-custom".to_string()]),
        max_age: None,
        allow_credentials: Some(true),
    };

    let mut response = Response::new(Body::empty());
    add_cors_headers(&mut response, "https://trusted.com", &config);

    let expected = [
        ("access-control-allow-origin", "https://trusted.com"),
        ("access-control-expose-headers", "x-custom"),
        ("access-control-allow-credentials", "true"),
    ];
    for (name, value) in expected {
        assert_eq!(response.headers().get(name).unwrap(), value);
    }
}
900
-
901
/// SECURITY TEST: `validate_cors_request` accepts the configured origin, rejects
/// others, and lets requests without an Origin header through.
#[test]
fn test_validate_cors_request_origin_must_match() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let with_origin = |origin: &'static str| {
        let mut request = HeaderMap::new();
        request.insert("origin", HeaderValue::from_static(origin));
        request
    };

    assert!(validate_cors_request(&with_origin("https://trusted.com"), &config).is_ok());
    assert!(validate_cors_request(&with_origin("https://evil.com"), &config).is_err());
    assert!(validate_cors_request(&HeaderMap::new(), &config).is_ok());
}
924
-
925
/// SECURITY TEST: a preflight that omits `Access-Control-Request-Method` is
/// currently ACCEPTED as long as the origin is allowed.
///
/// `handle_preflight` only validates the method when the client sends one, so this
/// test pins that lenient behavior. It does not assert a rejection, despite what
/// the function name suggests.
#[test]
fn test_preflight_requires_access_control_request_method() {
    let config = make_cors_config();
    let mut headers = HeaderMap::new();
    headers.insert("origin", HeaderValue::from_static("https://example.com"));

    let result = handle_preflight(&headers, &config);
    assert!(
        result.is_ok(),
        "preflight without Access-Control-Request-Method is accepted when the origin is allowed"
    );
}
935
-
936
/// SECURITY TEST: Access-Control-Request-Method is compared ASCII
/// case-insensitively against the configured methods.
#[test]
fn test_preflight_method_case_insensitive() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string(), "POST".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    ["GET", "get", "Get", "POST", "post"].into_iter().for_each(|method| {
        let mut request = HeaderMap::new();
        request.insert("origin", HeaderValue::from_static("https://example.com"));
        request.insert("access-control-request-method", HeaderValue::from_str(method).unwrap());

        assert!(
            handle_preflight(&request, &config).is_ok(),
            "Method '{}' should be allowed (case-insensitive)",
            method
        );
    });
}
963
-
964
/// SECURITY TEST: the configured `max_age` is emitted verbatim as
/// Access-Control-Max-Age on a successful preflight.
#[test]
fn test_preflight_max_age_header() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: Some(7200),
        allow_credentials: None,
    };

    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://example.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("GET"));

    let response = handle_preflight(&request, &config).expect("preflight should succeed");
    let max_age = response.headers().get("access-control-max-age").unwrap();
    assert_eq!(max_age, "7200");
}
986
-
987
/// SECURITY TEST: partial wildcards such as `*.example.com` are NOT patterns.
/// They only match the literal string itself, which is the desired conservative
/// behavior.
#[test]
fn test_wildcard_patterns_not_supported() {
    let config = CorsConfig {
        allowed_origins: vec!["*.example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for would_match_a_glob in ["https://api.example.com", "https://example.com"] {
        assert!(!is_origin_allowed(would_match_a_glob, allowed));
    }

    assert!(is_origin_allowed("*.example.com", allowed));
}
1005
- }
1
+ //! CORS (Cross-Origin Resource Sharing) handling
2
+ //!
3
+ //! Handles CORS preflight requests and adds CORS headers to responses
4
+
5
+ use crate::CorsConfig;
6
+ use axum::body::Body;
7
+ use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
8
+ use axum::response::IntoResponse;
9
+
10
/// Decide whether a request origin is permitted by the configured allow-list.
///
/// A configured entry of `"*"` admits any non-empty origin. Every other entry must
/// equal the origin byte-for-byte, so scheme, host, port, letter case and any
/// trailing slash all matter. An empty origin is never permitted.
///
/// # Arguments
/// * `origin` - The origin string from the HTTP request (e.g., "https://example.com")
/// * `allowed_origins` - List of allowed origins configured for CORS
///
/// # Returns
/// `true` if the origin is allowed, `false` otherwise
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "*" || candidate == origin)
}
30
+
31
/// Decide whether an HTTP method is permitted by the configured method list.
///
/// A configured `"*"` admits every method. Otherwise the method must equal one of
/// the entries, ignoring ASCII case (so "get" matches "GET").
///
/// # Arguments
/// * `method` - The HTTP method requested (e.g., "GET", "POST")
/// * `allowed_methods` - List of allowed HTTP methods configured for CORS
///
/// # Returns
/// `true` if the method is allowed, `false` otherwise
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    for candidate in allowed_methods {
        if candidate == "*" || candidate.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
47
+
48
/// Decide whether every requested header name is permitted by the configuration.
///
/// A configured `"*"` admits any set of headers. Otherwise each requested name must
/// match some configured name, ignoring ASCII case. An empty request list is
/// trivially allowed.
///
/// # Arguments
/// * `requested` - Header names requested by the client
/// * `allowed` - Header names configured for CORS
///
/// # Returns
/// `true` if all requested headers are allowed, `false` if any header is not allowed
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    let wildcard_configured = allowed.iter().any(|entry| entry == "*");

    wildcard_configured
        || requested
            .iter()
            .all(|name| allowed.iter().any(|entry| entry.eq_ignore_ascii_case(name)))
}
70
+
71
+ /// Handle CORS preflight (OPTIONS) request
72
+ ///
73
+ /// Validates the request against the CORS configuration and returns appropriate
74
+ /// response or error. This function processes OPTIONS requests as defined in the
75
+ /// CORS specification (RFC 7231).
76
+ ///
77
+ /// # Validation
78
+ ///
79
+ /// Checks the following conditions:
80
+ /// 1. **Origin Header:** Must be present and match configured allowed origins
81
+ /// 2. **Access-Control-Request-Method:** Must match configured allowed methods
82
+ /// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
83
+ ///
84
+ /// # Success Response
85
+ ///
86
+ /// Returns HTTP 204 (No Content) with the following response headers:
87
+ /// - `Access-Control-Allow-Origin` - The origin that is allowed
88
+ /// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
89
+ /// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
90
+ /// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
91
+ /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
92
+ ///
93
+ /// # Error Response
94
+ ///
95
+ /// Returns HTTP 403 (Forbidden) if validation fails for:
96
+ /// - Origin not in allowed list
97
+ /// - Requested method not allowed
98
+ /// - Requested headers not allowed
99
+ ///
100
+ /// # Arguments
101
+ /// * `headers` - Request headers containing CORS preflight information
102
+ /// * `cors_config` - CORS configuration to validate against
103
+ ///
104
+ /// # Returns
105
+ /// * `Ok(Response)` - 204 No Content with CORS headers
106
+ /// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
107
+ pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
108
+ let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
109
+
110
+ if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
111
+ return Err(Box::new(
112
+ (
113
+ StatusCode::FORBIDDEN,
114
+ axum::Json(serde_json::json!({
115
+ "detail": format!("CORS request from origin '{}' not allowed", origin)
116
+ })),
117
+ )
118
+ .into_response(),
119
+ ));
120
+ }
121
+
122
+ let requested_method = headers
123
+ .get("access-control-request-method")
124
+ .and_then(|v| v.to_str().ok())
125
+ .unwrap_or("");
126
+
127
+ if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
128
+ return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
129
+ }
130
+
131
+ let requested_headers_str = headers
132
+ .get("access-control-request-headers")
133
+ .and_then(|v| v.to_str().ok());
134
+
135
+ if let Some(req_headers) = requested_headers_str {
136
+ let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
137
+
138
+ if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
139
+ return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
140
+ }
141
+ }
142
+
143
+ let mut response = Response::builder().status(StatusCode::NO_CONTENT);
144
+
145
+ let headers_mut = match response.headers_mut() {
146
+ Some(headers) => headers,
147
+ None => {
148
+ return Err(Box::new(
149
+ (
150
+ StatusCode::INTERNAL_SERVER_ERROR,
151
+ axum::Json(serde_json::json!({
152
+ "detail": "Failed to construct response headers"
153
+ })),
154
+ )
155
+ .into_response(),
156
+ ));
157
+ }
158
+ };
159
+
160
+ headers_mut.insert(
161
+ "access-control-allow-origin",
162
+ HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
163
+ );
164
+
165
+ let methods = cors_config.allowed_methods.join(", ");
166
+ headers_mut.insert(
167
+ "access-control-allow-methods",
168
+ HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
169
+ );
170
+
171
+ let allowed_headers = cors_config.allowed_headers.join(", ");
172
+ headers_mut.insert(
173
+ "access-control-allow-headers",
174
+ HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
175
+ );
176
+
177
+ if let Some(max_age) = cors_config.max_age
178
+ && let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
179
+ {
180
+ headers_mut.insert("access-control-max-age", header_val);
181
+ }
182
+
183
+ if let Some(true) = cors_config.allow_credentials {
184
+ headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
185
+ }
186
+
187
+ match response.body(Body::empty()) {
188
+ Ok(resp) => Ok(resp),
189
+ Err(_) => Err(Box::new(
190
+ (
191
+ StatusCode::INTERNAL_SERVER_ERROR,
192
+ axum::Json(serde_json::json!({
193
+ "detail": "Failed to construct response body"
194
+ })),
195
+ )
196
+ .into_response(),
197
+ )),
198
+ }
199
+ }
200
+
201
+ /// Add CORS headers to a successful response
202
+ ///
203
+ /// Adds appropriate CORS headers to the response based on the configuration.
204
+ /// This function should be called for successful (non-error) responses to
205
+ /// cross-origin requests.
206
+ ///
207
+ /// # Headers Added
208
+ ///
209
+ /// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
210
+ /// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
211
+ /// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
212
+ ///
213
+ /// # Arguments
214
+ /// * `response` - Mutable reference to the response to modify
215
+ /// * `origin` - The origin from the request (e.g., `<https://example.com>`)
216
+ /// * `cors_config` - CORS configuration to apply
217
+ ///
218
+ /// # Example
219
+ ///
220
+ /// ```ignore
221
+ /// let mut response = Response::new(Body::empty());
222
+ /// add_cors_headers(&mut response, "https://example.com", &cors_config);
223
+ /// ```
224
+ pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
225
+ let headers = response.headers_mut();
226
+
227
+ if let Ok(origin_value) = HeaderValue::from_str(origin) {
228
+ headers.insert("access-control-allow-origin", origin_value);
229
+ }
230
+
231
+ if let Some(ref expose_headers) = cors_config.expose_headers {
232
+ let expose = expose_headers.join(", ");
233
+ if let Ok(expose_value) = HeaderValue::from_str(&expose) {
234
+ headers.insert("access-control-expose-headers", expose_value);
235
+ }
236
+ }
237
+
238
+ if let Some(true) = cors_config.allow_credentials {
239
+ headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
240
+ }
241
+ }
242
+
243
+ /// Validate a non-preflight CORS request
244
+ ///
245
+ /// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
246
+ /// Returns an error response if validation fails.
247
+ ///
248
+ /// # Validation
249
+ ///
250
+ /// - If no Origin header is present, the request is allowed (not a CORS request)
251
+ /// - If Origin header is present, it must match the allowed origins
252
+ ///
253
+ /// # Arguments
254
+ /// * `headers` - Request headers containing origin information
255
+ /// * `cors_config` - CORS configuration to validate against
256
+ ///
257
+ /// # Returns
258
+ /// * `Ok(())` - Request is allowed
259
+ /// * `Err(Response)` - 403 Forbidden with error details
260
+ ///
261
+ /// # Note
262
+ ///
263
+ /// This function is for actual requests, not OPTIONS preflight requests.
264
+ /// Use `handle_preflight` for OPTIONS requests.
265
+ pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
266
+ let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
267
+
268
+ if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
269
+ return Err(Box::new(
270
+ (
271
+ StatusCode::FORBIDDEN,
272
+ axum::Json(serde_json::json!({
273
+ "detail": format!("CORS request from origin '{}' not allowed", origin)
274
+ })),
275
+ )
276
+ .into_response(),
277
+ ));
278
+ }
279
+ Ok(())
280
+ }
281
+
282
+ #[cfg(test)]
283
+ mod tests {
284
+ use super::*;
285
+
286
+ fn make_cors_config() -> CorsConfig {
287
+ CorsConfig {
288
+ allowed_origins: vec!["https://example.com".to_string()],
289
+ allowed_methods: vec!["GET".to_string(), "POST".to_string()],
290
+ allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
291
+ expose_headers: Some(vec!["x-custom-header".to_string()]),
292
+ max_age: Some(3600),
293
+ allow_credentials: Some(true),
294
+ }
295
+ }
296
+
297
#[test]
fn test_is_origin_allowed_exact_match() {
    let allowed = ["https://example.com".to_string()];

    let accepted = is_origin_allowed("https://example.com", &allowed);
    let rejected = is_origin_allowed("https://evil.com", &allowed);

    assert!(accepted);
    assert!(!rejected);
}
303
+
304
#[test]
fn test_is_origin_allowed_wildcard() {
    let allowed = ["*".to_string()];
    for origin in ["https://example.com", "https://any-domain.com"] {
        assert!(is_origin_allowed(origin, &allowed));
    }
}
310
+
311
#[test]
fn test_is_origin_allowed_empty_origin() {
    // Even a wildcard must not admit an empty origin.
    let wildcard_only = ["*".to_string()];
    assert!(!is_origin_allowed("", &wildcard_only));
}
316
+
317
#[test]
fn test_is_method_allowed_case_insensitive() {
    let allowed = ["GET".to_string(), "POST".to_string()];

    for method in ["GET", "get", "POST", "post"] {
        assert!(is_method_allowed(method, &allowed), "{method} should be allowed");
    }
    assert!(!is_method_allowed("DELETE", &allowed));
}
326
+
327
#[test]
fn test_is_method_allowed_wildcard() {
    let allowed = ["*".to_string()];
    let all_pass = ["GET", "DELETE", "PATCH"]
        .iter()
        .all(|method| is_method_allowed(method, &allowed));
    assert!(all_pass);
}
334
+
335
#[test]
fn test_are_headers_allowed_case_insensitive() {
    let allowed = ["Content-Type".to_string(), "Authorization".to_string()];

    let accepted: [&[&str]; 3] = [&["content-type"], &["AUTHORIZATION"], &["content-type", "authorization"]];
    for requested in accepted {
        assert!(are_headers_allowed(requested, &allowed));
    }

    assert!(!are_headers_allowed(&["x-custom"], &allowed));
}
343
+
344
#[test]
fn test_are_headers_allowed_wildcard() {
    let allowed = ["*".to_string()];
    let single: &[&str] = &["any-header"];
    let several: &[&str] = &["multiple", "headers"];
    assert!(are_headers_allowed(single, &allowed));
    assert!(are_headers_allowed(several, &allowed));
}
350
+
351
#[test]
fn test_handle_preflight_success() {
    let config = make_cors_config();

    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://example.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("POST"));
    request.insert(
        "access-control-request-headers",
        HeaderValue::from_static("content-type"),
    );

    let response = handle_preflight(&request, &config).expect("preflight should succeed");
    assert_eq!(response.status(), StatusCode::NO_CONTENT);

    let out = response.headers();
    assert_eq!(out.get("access-control-allow-origin").unwrap(), "https://example.com");

    let allow_methods = out.get("access-control-allow-methods").unwrap().to_str().unwrap();
    assert!(allow_methods.contains("POST"));

    assert_eq!(out.get("access-control-max-age").unwrap(), "3600");
    assert_eq!(out.get("access-control-allow-credentials").unwrap(), "true");
}
384
+
385
#[test]
fn test_handle_preflight_origin_not_allowed() {
    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://evil.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("GET"));

    let rejection = handle_preflight(&request, &make_cors_config()).expect_err("origin must be rejected");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
398
+
399
#[test]
fn test_handle_preflight_method_not_allowed() {
    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://example.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("DELETE"));

    let rejection = handle_preflight(&request, &make_cors_config()).expect_err("DELETE is not configured");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
412
+
413
#[test]
fn test_handle_preflight_header_not_allowed() {
    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://example.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("POST"));
    request.insert(
        "access-control-request-headers",
        HeaderValue::from_static("x-forbidden-header"),
    );

    let rejection = handle_preflight(&request, &make_cors_config()).expect_err("header is not configured");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
430
+
431
#[test]
fn test_handle_preflight_empty_origin() {
    // No Origin header at all must be treated as a forbidden preflight.
    let no_headers = HeaderMap::new();

    let rejection = handle_preflight(&no_headers, &make_cors_config()).expect_err("missing origin must be rejected");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
442
+
443
#[test]
fn test_add_cors_headers() {
    let config = make_cors_config();
    let mut response = Response::new(Body::empty());

    add_cors_headers(&mut response, "https://example.com", &config);

    let expectations = [
        ("access-control-allow-origin", "https://example.com"),
        ("access-control-expose-headers", "x-custom-header"),
        ("access-control-allow-credentials", "true"),
    ];
    for (name, value) in expectations {
        assert_eq!(response.headers().get(name).unwrap(), value);
    }
}
458
+
459
#[test]
fn test_validate_cors_request_allowed() {
    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://example.com"));

    assert!(validate_cors_request(&request, &make_cors_config()).is_ok());
}
468
+
469
#[test]
fn test_validate_cors_request_not_allowed() {
    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://evil.com"));

    let rejection = validate_cors_request(&request, &make_cors_config()).expect_err("origin must be rejected");
    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
481
+
482
#[test]
fn test_validate_cors_request_no_origin() {
    // Without an Origin header the request is not cross-origin and passes.
    let outcome = validate_cors_request(&HeaderMap::new(), &make_cors_config());
    assert!(outcome.is_ok());
}
490
+
491
+ // SECURITY TESTS: CORS Attack Vectors
492
+
493
/// SECURITY TEST: a preflight must never pair `Access-Control-Allow-Credentials: true`
/// with a literal `Access-Control-Allow-Origin: *`.
///
/// The Fetch Standard's CORS check rejects that combination. Browsers would
/// refuse it, and emitting it signals a misconfigured server.
///
/// With a `"*"` origin list and credentials enabled, the current implementation
/// echoes the concrete request origin instead of `*`, so this assertion holds. Note
/// that the echo still grants credentialed access to any origin. If the preflight
/// is rejected outright, there is nothing to check and the test passes.
#[test]
fn test_credentials_with_wildcard_config_is_security_issue() {
    let config = CorsConfig {
        allowed_origins: vec!["*".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        // Risky combination under test: wildcard origins plus credentials.
        allow_credentials: Some(true),
    };

    let mut headers = HeaderMap::new();
    headers.insert("origin", HeaderValue::from_static("https://evil.com"));
    headers.insert("access-control-request-method", HeaderValue::from_static("GET"));

    if let Ok(response) = handle_preflight(&headers, &config) {
        let resp_headers = response.headers();
        let allows_credentials = resp_headers
            .get("access-control-allow-credentials")
            .is_some_and(|v| v == "true");
        let allow_origin = resp_headers
            .get("access-control-allow-origin")
            .and_then(|v| v.to_str().ok());

        assert!(
            !(allows_credentials && allow_origin == Some("*")),
            "SECURITY VULNERABILITY: credentials=true with origin=* allowed"
        );
    }
}
529
+
530
/// SECURITY TEST: origin matching is exact, so subdomains of an allowed host
/// (e.g. `api.example.com`, `sub.sub.example.com`) are not admitted.
#[test]
fn test_subdomain_bypass_blocked() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for subdomain in [
        "https://api.example.com",
        "https://evil.example.com",
        "https://sub.sub.example.com",
    ] {
        assert!(!is_origin_allowed(subdomain, allowed));
    }

    assert!(is_origin_allowed("https://example.com", allowed));
}
552
+
553
/// SECURITY TEST: the port is part of the origin; `http://localhost:3000` in the
/// allow-list does not admit the same host on any other port.
#[test]
fn test_port_bypass_blocked() {
    let config = CorsConfig {
        allowed_origins: vec!["http://localhost:3000".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    let other_ports = ["http://localhost:3001", "http://localhost:8080", "http://localhost:443"];
    for origin in other_ports {
        assert!(!is_origin_allowed(origin, allowed));
    }

    assert!(is_origin_allowed("http://localhost:3000", allowed));
}
572
+
573
/// SECURITY TEST: the scheme is part of the origin; an `https://` entry does not
/// admit `http://`, `ws://` or `wss://` variants of the same host.
#[test]
fn test_protocol_downgrade_attack_blocked() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for scheme_variant in ["http://example.com", "ws://example.com", "wss://example.com"] {
        assert!(!is_origin_allowed(scheme_variant, allowed));
    }

    assert!(is_origin_allowed("https://example.com", allowed));
}
592
+
593
/// Documents that `is_origin_allowed` compares origins byte-for-byte, with no
/// case folding: `https://Example.Com` matches only itself.
///
/// Note for configuration: RFC 6454 §4 lowercases the scheme and host when
/// computing an origin, and browsers send the `Origin` header in that form.
/// A mixed-case allow-list entry such as the one below therefore never matches
/// a real browser request. Write configured origins in lowercase.
#[test]
fn test_case_sensitive_origin_matching() {
    let config = CorsConfig {
        allowed_origins: vec!["https://Example.Com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    // Differently-cased spellings of the same host are not matched.
    assert!(!is_origin_allowed("https://example.com", &config.allowed_origins));
    assert!(!is_origin_allowed("https://EXAMPLE.COM", &config.allowed_origins));

    assert!(is_origin_allowed("https://Example.Com", &config.allowed_origins));
}
611
+
612
/// SECURITY TEST: origins are not normalized, so `https://example.com/` (with a
/// trailing slash) is distinct from the allowed `https://example.com`.
#[test]
fn test_trailing_slash_origin_not_normalized() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    let with_slash_allowed = is_origin_allowed("https://example.com/", allowed);
    let exact_allowed = is_origin_allowed("https://example.com", allowed);

    assert!(!with_slash_allowed);
    assert!(exact_allowed);
}
629
+
630
/// SECURITY TEST: the special `"null"` origin (sent by `file://` pages and
/// sandboxed iframes) is handled as an ordinary string.
///
/// It is therefore admitted both by a `"*"` wildcard and by an explicit `"null"`
/// entry. Admission by the wildcard is arguably undesirable; this test documents
/// the current behavior.
#[test]
fn test_null_origin_with_wildcard() {
    let single_origin_config = |origin: &str| CorsConfig {
        allowed_origins: vec![origin.to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let wildcard = single_origin_config("*");
    // SECURITY NOTE: the wildcard currently lets "null" through.
    assert!(is_origin_allowed("null", &wildcard.allowed_origins));

    let explicit_null = single_origin_config("null");
    assert!(is_origin_allowed("null", &explicit_null.allowed_origins));
}
658
+
659
/// SECURITY TEST: an empty origin string is rejected both under a wildcard and
/// under an explicit allow-list.
#[test]
fn test_empty_origin_always_rejected() {
    let single_origin_config = |origin: &str| CorsConfig {
        allowed_origins: vec![origin.to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    for entry in ["*", "https://example.com"] {
        let config = single_origin_config(entry);
        assert!(!is_origin_allowed("", &config.allowed_origins));
    }
}
682
+
683
/// SECURITY TEST: a preflight from an origin outside the allow-list is refused
/// with 403 Forbidden.
#[test]
fn test_preflight_rejects_invalid_origin() {
    let cfg = make_cors_config();

    let mut preflight = HeaderMap::new();
    preflight.insert("origin", HeaderValue::from_static("https://untrusted.com"));
    preflight.insert("access-control-request-method", HeaderValue::from_static("POST"));

    let outcome = handle_preflight(&preflight, &cfg);
    assert!(outcome.is_err());

    let rejected_status = outcome.unwrap_err().status();
    assert_eq!(rejected_status, StatusCode::FORBIDDEN);
}
697
+
698
/// SECURITY TEST: with several allowed origins, each must be matched exactly.
///
/// Suffix-extended, neighbouring, or truncated names are rejected.
#[test]
fn test_multiple_origins_exact_matching() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted1.com".to_string(), "https://trusted2.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for good in ["https://trusted1.com", "https://trusted2.com"] {
        assert!(is_origin_allowed(good, allowed));
    }

    for bad in [
        "https://trusted1.com.evil.com",
        "https://trusted3.com",
        "https://trusted.com",
    ] {
        assert!(!is_origin_allowed(bad, allowed));
    }
}
720
+
721
/// SECURITY TEST: a `"*"` allow-list admits any non-empty origin, while the
/// empty string is still rejected.
///
/// Despite the name, this test covers origin admission only; it does not
/// inspect credential behavior.
#[test]
fn test_wildcard_allows_all_but_check_credentials() {
    let config = CorsConfig {
        allowed_origins: vec!["*".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    let admitted = ["https://example.com", "https://evil.com", "http://localhost:3000"];
    assert!(admitted.iter().all(|origin| is_origin_allowed(origin, allowed)));

    assert!(!is_origin_allowed("", allowed));
}
739
+
740
/// SECURITY TEST: a successful preflight mirrors the configuration.
///
/// The response must:
/// - echo the trusted origin;
/// - list the configured methods;
/// - omit `access-control-allow-credentials` when `allow_credentials` is `Some(false)`.
#[test]
fn test_preflight_response_has_correct_allowed_origins() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string(), "POST".to_string()],
        allowed_headers: vec!["content-type".to_string()],
        expose_headers: None,
        max_age: Some(3600),
        allow_credentials: Some(false),
    };

    let mut request = HeaderMap::new();
    request.insert("origin", HeaderValue::from_static("https://trusted.com"));
    request.insert("access-control-request-method", HeaderValue::from_static("POST"));
    request.insert("access-control-request-headers", HeaderValue::from_static("content-type"));

    let outcome = handle_preflight(&request, &config);
    assert!(outcome.is_ok());
    let response = outcome.unwrap();
    let resp_headers = response.headers();

    assert_eq!(
        resp_headers.get("access-control-allow-origin").unwrap(),
        "https://trusted.com"
    );

    // Read the methods header once and check each configured method appears.
    let allow_methods = resp_headers
        .get("access-control-allow-methods")
        .unwrap()
        .to_str()
        .unwrap();
    for method in ["GET", "POST"] {
        assert!(allow_methods.contains(method));
    }

    assert!(resp_headers.get("access-control-allow-credentials").is_none());
}
790
+
791
/// SECURITY TEST: every preflight origin is validated against the allow-list.
///
/// Only the exact trusted origin passes. A suffix-extended host, a different
/// scheme, an unrelated host, and an empty origin are all rejected.
#[test]
fn test_preflight_all_origins_require_validation() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    // (origin, expected to be allowed)
    let test_cases = [
        ("https://trusted.com", true),
        ("https://evil.com", false),
        ("https://trusted.com.evil", false),
        ("http://trusted.com", false),
        ("", false),
    ];

    for (origin, should_allow) in test_cases {
        let mut headers = HeaderMap::new();
        // Every entry, including "", is a valid header value. A failure here
        // means the table itself is broken, so fail loudly instead of
        // silently substituting another value.
        headers.insert(
            "origin",
            HeaderValue::from_str(origin).expect("test origin must be a valid header value"),
        );
        headers.insert("access-control-request-method", HeaderValue::from_static("GET"));

        let result = handle_preflight(&headers, &config);

        if should_allow {
            assert!(result.is_ok(), "Valid origin {} should be allowed", origin);
        } else {
            assert!(result.is_err(), "Invalid origin {} should be rejected", origin);
        }
    }
}
828
+
829
/// SECURITY TEST: every header named in `access-control-request-headers` must be
/// in the allow-list.
///
/// A single unlisted header fails the whole preflight.
#[test]
fn test_preflight_validates_all_requested_headers() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["POST".to_string()],
        allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let cases = [
        ("content-type", true),
        ("authorization", true),
        ("content-type, authorization", true),
        ("x-custom-header", false),
        ("content-type, x-custom", false),
    ];

    for (requested, expect_pass) in cases {
        let mut preflight = HeaderMap::new();
        preflight.insert("origin", HeaderValue::from_static("https://trusted.com"));
        preflight.insert("access-control-request-method", HeaderValue::from_static("POST"));
        preflight.insert(
            "access-control-request-headers",
            HeaderValue::from_str(requested).unwrap(),
        );

        let outcome = handle_preflight(&preflight, &config);

        match expect_pass {
            true => assert!(
                outcome.is_ok(),
                "Preflight with valid headers '{}' should pass",
                requested
            ),
            false => assert!(
                outcome.is_err(),
                "Preflight with invalid headers '{}' should fail",
                requested
            ),
        }
    }
}
875
+
876
/// SECURITY TEST: `add_cors_headers` echoes the given origin and applies the
/// configured expose-headers and credentials flag.
#[test]
fn test_add_cors_headers_respects_origin() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: Some(vec!["x-custom".to_string()]),
        max_age: None,
        allow_credentials: Some(true),
    };

    let mut resp = Response::new(Body::empty());
    add_cors_headers(&mut resp, "https://trusted.com", &config);

    let hdrs = resp.headers();
    let expected = [
        ("access-control-allow-origin", "https://trusted.com"),
        ("access-control-expose-headers", "x-custom"),
        ("access-control-allow-credentials", "true"),
    ];
    for (name, value) in expected {
        assert_eq!(hdrs.get(name).unwrap(), value);
    }
}
900
+
901
/// SECURITY TEST: `validate_cors_request` has three outcomes:
/// - a trusted origin is accepted;
/// - a foreign origin is rejected;
/// - a request without an `Origin` header is accepted.
#[test]
fn test_validate_cors_request_origin_must_match() {
    let config = CorsConfig {
        allowed_origins: vec!["https://trusted.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    // (Origin header value, or None for no header; expected to validate)
    let cases: [(Option<&'static str>, bool); 3] = [
        (Some("https://trusted.com"), true),
        (Some("https://evil.com"), false),
        (None, true),
    ];

    for (origin, expect_ok) in cases {
        let mut request_headers = HeaderMap::new();
        if let Some(value) = origin {
            request_headers.insert("origin", HeaderValue::from_static(value));
        }
        assert_eq!(validate_cors_request(&request_headers, &config).is_ok(), expect_ok);
    }
}
924
+
925
/// SECURITY TEST: behavior of a preflight that omits
/// `Access-Control-Request-Method`.
///
/// Despite the function name, this test asserts that `handle_preflight` returns
/// `Ok` when only an `Origin` header is present, i.e. the request is currently
/// accepted. Per the Fetch Standard, browsers always send
/// `Access-Control-Request-Method` on a genuine CORS preflight. If the
/// implementation is changed to reject such requests, flip this assertion to
/// `is_err()`.
#[test]
fn test_preflight_requires_access_control_request_method() {
    let config = make_cors_config();
    let mut headers = HeaderMap::new();
    headers.insert("origin", HeaderValue::from_static("https://example.com"));

    let result = handle_preflight(&headers, &config);
    assert!(
        result.is_ok(),
        "preflight without access-control-request-method is currently expected to be accepted"
    );
}
935
+
936
/// SECURITY TEST: `access-control-request-method` is matched against the
/// configured methods regardless of letter case.
#[test]
fn test_preflight_method_case_insensitive() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string(), "POST".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    for requested in ["GET", "get", "Get", "POST", "post"] {
        let mut preflight = HeaderMap::new();
        preflight.insert("origin", HeaderValue::from_static("https://example.com"));
        preflight.insert(
            "access-control-request-method",
            HeaderValue::from_str(requested).unwrap(),
        );

        let accepted = handle_preflight(&preflight, &config).is_ok();
        assert!(
            accepted,
            "Method '{}' should be allowed (case-insensitive)",
            requested
        );
    }
}
963
+
964
/// SECURITY TEST: a configured `max_age` is emitted verbatim as
/// `access-control-max-age` on a successful preflight.
#[test]
fn test_preflight_max_age_header() {
    let config = CorsConfig {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: Some(7200),
        allow_credentials: None,
    };

    let mut preflight = HeaderMap::new();
    preflight.insert("origin", HeaderValue::from_static("https://example.com"));
    preflight.insert("access-control-request-method", HeaderValue::from_static("GET"));

    let outcome = handle_preflight(&preflight, &config);
    assert!(outcome.is_ok());

    let response = outcome.unwrap();
    let max_age = response.headers().get("access-control-max-age").unwrap();
    assert_eq!(max_age, "7200");
}
986
+
987
/// SECURITY TEST: glob-style entries such as `*.example.com` are not expanded.
///
/// They match only the literal string itself, never real subdomains or the
/// apex domain.
#[test]
fn test_wildcard_patterns_not_supported() {
    let config = CorsConfig {
        allowed_origins: vec!["*.example.com".to_string()],
        allowed_methods: vec!["GET".to_string()],
        allowed_headers: vec![],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allowed = &config.allowed_origins;

    for real_origin in ["https://api.example.com", "https://example.com"] {
        assert!(!is_origin_allowed(real_origin, allowed));
    }

    // Only the literal pattern text compares equal.
    assert!(is_origin_allowed("*.example.com", allowed));
}
1005
+ }