spikard 0.5.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +674 -674
- data/ext/spikard_rb/Cargo.toml +17 -17
- data/ext/spikard_rb/extconf.rb +13 -10
- data/ext/spikard_rb/src/lib.rs +6 -6
- data/lib/spikard/app.rb +405 -405
- data/lib/spikard/background.rb +27 -27
- data/lib/spikard/config.rb +396 -396
- data/lib/spikard/converters.rb +13 -13
- data/lib/spikard/handler_wrapper.rb +113 -113
- data/lib/spikard/provide.rb +214 -214
- data/lib/spikard/response.rb +173 -173
- data/lib/spikard/schema.rb +243 -243
- data/lib/spikard/sse.rb +111 -111
- data/lib/spikard/streaming_response.rb +44 -44
- data/lib/spikard/testing.rb +256 -256
- data/lib/spikard/upload_file.rb +131 -131
- data/lib/spikard/version.rb +5 -5
- data/lib/spikard/websocket.rb +59 -59
- data/lib/spikard.rb +43 -43
- data/sig/spikard.rbs +366 -366
- data/vendor/crates/spikard-bindings-shared/Cargo.toml +63 -63
- data/vendor/crates/spikard-bindings-shared/examples/config_extraction.rs +132 -132
- data/vendor/crates/spikard-bindings-shared/src/config_extractor.rs +752 -752
- data/vendor/crates/spikard-bindings-shared/src/conversion_traits.rs +194 -194
- data/vendor/crates/spikard-bindings-shared/src/di_traits.rs +246 -246
- data/vendor/crates/spikard-bindings-shared/src/error_response.rs +401 -401
- data/vendor/crates/spikard-bindings-shared/src/handler_base.rs +238 -238
- data/vendor/crates/spikard-bindings-shared/src/lib.rs +24 -24
- data/vendor/crates/spikard-bindings-shared/src/lifecycle_base.rs +292 -292
- data/vendor/crates/spikard-bindings-shared/src/lifecycle_executor.rs +616 -616
- data/vendor/crates/spikard-bindings-shared/src/response_builder.rs +305 -305
- data/vendor/crates/spikard-bindings-shared/src/test_client_base.rs +248 -248
- data/vendor/crates/spikard-bindings-shared/src/validation_helpers.rs +351 -351
- data/vendor/crates/spikard-bindings-shared/tests/comprehensive_coverage.rs +454 -454
- data/vendor/crates/spikard-bindings-shared/tests/error_response_edge_cases.rs +383 -383
- data/vendor/crates/spikard-bindings-shared/tests/handler_base_integration.rs +280 -280
- data/vendor/crates/spikard-core/Cargo.toml +40 -40
- data/vendor/crates/spikard-core/src/bindings/mod.rs +3 -3
- data/vendor/crates/spikard-core/src/bindings/response.rs +133 -133
- data/vendor/crates/spikard-core/src/debug.rs +127 -127
- data/vendor/crates/spikard-core/src/di/container.rs +702 -702
- data/vendor/crates/spikard-core/src/di/dependency.rs +273 -273
- data/vendor/crates/spikard-core/src/di/error.rs +118 -118
- data/vendor/crates/spikard-core/src/di/factory.rs +534 -534
- data/vendor/crates/spikard-core/src/di/graph.rs +506 -506
- data/vendor/crates/spikard-core/src/di/mod.rs +192 -192
- data/vendor/crates/spikard-core/src/di/resolved.rs +405 -405
- data/vendor/crates/spikard-core/src/di/value.rs +281 -281
- data/vendor/crates/spikard-core/src/errors.rs +69 -69
- data/vendor/crates/spikard-core/src/http.rs +415 -415
- data/vendor/crates/spikard-core/src/lib.rs +29 -29
- data/vendor/crates/spikard-core/src/lifecycle.rs +1186 -1186
- data/vendor/crates/spikard-core/src/metadata.rs +389 -389
- data/vendor/crates/spikard-core/src/parameters.rs +2525 -2525
- data/vendor/crates/spikard-core/src/problem.rs +344 -344
- data/vendor/crates/spikard-core/src/request_data.rs +1154 -1154
- data/vendor/crates/spikard-core/src/router.rs +510 -510
- data/vendor/crates/spikard-core/src/schema_registry.rs +183 -183
- data/vendor/crates/spikard-core/src/type_hints.rs +304 -304
- data/vendor/crates/spikard-core/src/validation/error_mapper.rs +696 -688
- data/vendor/crates/spikard-core/src/validation/mod.rs +457 -457
- data/vendor/crates/spikard-http/Cargo.toml +62 -64
- data/vendor/crates/spikard-http/examples/sse-notifications.rs +148 -148
- data/vendor/crates/spikard-http/examples/websocket-chat.rs +92 -92
- data/vendor/crates/spikard-http/src/auth.rs +296 -296
- data/vendor/crates/spikard-http/src/background.rs +1860 -1860
- data/vendor/crates/spikard-http/src/bindings/mod.rs +3 -3
- data/vendor/crates/spikard-http/src/bindings/response.rs +1 -1
- data/vendor/crates/spikard-http/src/body_metadata.rs +8 -8
- data/vendor/crates/spikard-http/src/cors.rs +1005 -1005
- data/vendor/crates/spikard-http/src/debug.rs +128 -128
- data/vendor/crates/spikard-http/src/di_handler.rs +1668 -1668
- data/vendor/crates/spikard-http/src/handler_response.rs +901 -901
- data/vendor/crates/spikard-http/src/handler_trait.rs +838 -830
- data/vendor/crates/spikard-http/src/handler_trait_tests.rs +290 -290
- data/vendor/crates/spikard-http/src/lib.rs +534 -534
- data/vendor/crates/spikard-http/src/lifecycle/adapter.rs +230 -230
- data/vendor/crates/spikard-http/src/lifecycle.rs +1193 -1193
- data/vendor/crates/spikard-http/src/middleware/mod.rs +560 -540
- data/vendor/crates/spikard-http/src/middleware/multipart.rs +912 -912
- data/vendor/crates/spikard-http/src/middleware/urlencoded.rs +513 -513
- data/vendor/crates/spikard-http/src/middleware/validation.rs +768 -735
- data/vendor/crates/spikard-http/src/openapi/mod.rs +309 -309
- data/vendor/crates/spikard-http/src/openapi/parameter_extraction.rs +535 -535
- data/vendor/crates/spikard-http/src/openapi/schema_conversion.rs +1363 -1363
- data/vendor/crates/spikard-http/src/openapi/spec_generation.rs +665 -665
- data/vendor/crates/spikard-http/src/query_parser.rs +793 -793
- data/vendor/crates/spikard-http/src/response.rs +720 -720
- data/vendor/crates/spikard-http/src/server/handler.rs +1650 -1650
- data/vendor/crates/spikard-http/src/server/lifecycle_execution.rs +234 -234
- data/vendor/crates/spikard-http/src/server/mod.rs +1593 -1502
- data/vendor/crates/spikard-http/src/server/request_extraction.rs +789 -770
- data/vendor/crates/spikard-http/src/server/routing_factory.rs +629 -599
- data/vendor/crates/spikard-http/src/sse.rs +1409 -1409
- data/vendor/crates/spikard-http/src/testing/form.rs +52 -52
- data/vendor/crates/spikard-http/src/testing/multipart.rs +64 -60
- data/vendor/crates/spikard-http/src/testing/test_client.rs +311 -283
- data/vendor/crates/spikard-http/src/testing.rs +406 -377
- data/vendor/crates/spikard-http/src/websocket.rs +1404 -1375
- data/vendor/crates/spikard-http/tests/background_behavior.rs +832 -832
- data/vendor/crates/spikard-http/tests/common/handlers.rs +309 -309
- data/vendor/crates/spikard-http/tests/common/mod.rs +26 -26
- data/vendor/crates/spikard-http/tests/di_integration.rs +192 -192
- data/vendor/crates/spikard-http/tests/doc_snippets.rs +5 -5
- data/vendor/crates/spikard-http/tests/lifecycle_execution.rs +1093 -1093
- data/vendor/crates/spikard-http/tests/multipart_behavior.rs +656 -656
- data/vendor/crates/spikard-http/tests/server_config_builder.rs +314 -314
- data/vendor/crates/spikard-http/tests/sse_behavior.rs +620 -620
- data/vendor/crates/spikard-http/tests/websocket_behavior.rs +663 -663
- data/vendor/crates/spikard-rb/Cargo.toml +48 -48
- data/vendor/crates/spikard-rb/build.rs +199 -199
- data/vendor/crates/spikard-rb/src/background.rs +63 -63
- data/vendor/crates/spikard-rb/src/config/mod.rs +5 -5
- data/vendor/crates/spikard-rb/src/config/server_config.rs +285 -285
- data/vendor/crates/spikard-rb/src/conversion.rs +554 -554
- data/vendor/crates/spikard-rb/src/di/builder.rs +100 -100
- data/vendor/crates/spikard-rb/src/di/mod.rs +375 -375
- data/vendor/crates/spikard-rb/src/handler.rs +618 -618
- data/vendor/crates/spikard-rb/src/integration/mod.rs +3 -3
- data/vendor/crates/spikard-rb/src/lib.rs +1806 -1810
- data/vendor/crates/spikard-rb/src/lifecycle.rs +275 -275
- data/vendor/crates/spikard-rb/src/metadata/mod.rs +5 -5
- data/vendor/crates/spikard-rb/src/metadata/route_extraction.rs +442 -447
- data/vendor/crates/spikard-rb/src/runtime/mod.rs +5 -5
- data/vendor/crates/spikard-rb/src/runtime/server_runner.rs +324 -324
- data/vendor/crates/spikard-rb/src/server.rs +305 -308
- data/vendor/crates/spikard-rb/src/sse.rs +231 -231
- data/vendor/crates/spikard-rb/src/testing/client.rs +538 -551
- data/vendor/crates/spikard-rb/src/testing/mod.rs +7 -7
- data/vendor/crates/spikard-rb/src/testing/sse.rs +143 -143
- data/vendor/crates/spikard-rb/src/testing/websocket.rs +608 -635
- data/vendor/crates/spikard-rb/src/websocket.rs +377 -374
- metadata +15 -1
|
@@ -1,1005 +1,1005 @@
|
|
|
1
|
-
//! CORS (Cross-Origin Resource Sharing) handling
|
|
2
|
-
//!
|
|
3
|
-
//! Handles CORS preflight requests and adds CORS headers to responses
|
|
4
|
-
|
|
5
|
-
use crate::CorsConfig;
|
|
6
|
-
use axum::body::Body;
|
|
7
|
-
use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
|
|
8
|
-
use axum::response::IntoResponse;
|
|
9
|
-
|
|
10
|
-
/// Returns whether `origin` is permitted by the configured CORS origin list.
///
/// A configured entry of `"*"` permits every (non-empty) origin. Any other entry
/// must equal the origin exactly, byte for byte, so scheme, host, port and letter
/// case all have to match. An empty origin is never permitted.
///
/// # Arguments
///
/// * `origin` - Value of the request's `Origin` header (e.g., "https://example.com")
/// * `allowed_origins` - Origins configured for CORS
///
/// # Returns
///
/// `true` when at least one configured entry permits `origin`, `false` otherwise
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    // Reject empty origins up front for security, regardless of any wildcard.
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|entry| entry == "*" || entry == origin)
}
|
|
30
|
-
|
|
31
|
-
/// Returns whether an HTTP method is permitted by the CORS configuration.
///
/// A configured entry of `"*"` permits any method. Otherwise the method must
/// match a configured entry ignoring ASCII case (so "get" matches "GET").
///
/// # Arguments
///
/// * `method` - The HTTP method requested (e.g., "GET", "POST")
/// * `allowed_methods` - HTTP methods configured for CORS
///
/// # Returns
///
/// `true` as soon as a permitting entry is found, `false` if none exists
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    for candidate in allowed_methods {
        if candidate == "*" || candidate.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
|
|
47
|
-
|
|
48
|
-
/// Check if all requested headers are allowed by CORS configuration
///
/// Header names are compared case-insensitively. If a wildcard (`"*"`) is
/// configured, every requested header is allowed.
///
/// Each requested name is trimmed. Empty names are ignored, e.g. those produced
/// by a trailing comma (`"content-type, "`) or an empty
/// `Access-Control-Request-Headers` value. HTTP list syntax does not count
/// empty list elements as values, so they must not cause a rejection.
///
/// # Arguments
///
/// * `requested` - Header names requested by the client
/// * `allowed` - Header names configured for CORS
///
/// # Returns
///
/// `true` if every non-empty requested header is allowed (vacuously `true` when
/// none remain), `false` if any header is not allowed
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    if allowed.iter().any(|h| h == "*") {
        return true;
    }

    requested
        .iter()
        .map(|h| h.trim())
        // Empty list elements name no header; skip them instead of rejecting.
        .filter(|h| !h.is_empty())
        .all(|req_header| {
            allowed
                .iter()
                .any(|allowed_header| allowed_header.eq_ignore_ascii_case(req_header))
        })
}
|
|
70
|
-
|
|
71
|
-
/// Handle CORS preflight (OPTIONS) request
|
|
72
|
-
///
|
|
73
|
-
/// Validates the request against the CORS configuration and returns appropriate
|
|
74
|
-
/// response or error. This function processes OPTIONS requests as defined in the
|
|
75
|
-
/// CORS specification (RFC 7231).
|
|
76
|
-
///
|
|
77
|
-
/// # Validation
|
|
78
|
-
///
|
|
79
|
-
/// Checks the following conditions:
|
|
80
|
-
/// 1. **Origin Header:** Must be present and match configured allowed origins
|
|
81
|
-
/// 2. **Access-Control-Request-Method:** Must match configured allowed methods
|
|
82
|
-
/// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
|
|
83
|
-
///
|
|
84
|
-
/// # Success Response
|
|
85
|
-
///
|
|
86
|
-
/// Returns HTTP 204 (No Content) with the following response headers:
|
|
87
|
-
/// - `Access-Control-Allow-Origin` - The origin that is allowed
|
|
88
|
-
/// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
|
|
89
|
-
/// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
|
|
90
|
-
/// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
|
|
91
|
-
/// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
|
|
92
|
-
///
|
|
93
|
-
/// # Error Response
|
|
94
|
-
///
|
|
95
|
-
/// Returns HTTP 403 (Forbidden) if validation fails for:
|
|
96
|
-
/// - Origin not in allowed list
|
|
97
|
-
/// - Requested method not allowed
|
|
98
|
-
/// - Requested headers not allowed
|
|
99
|
-
///
|
|
100
|
-
/// # Arguments
|
|
101
|
-
/// * `headers` - Request headers containing CORS preflight information
|
|
102
|
-
/// * `cors_config` - CORS configuration to validate against
|
|
103
|
-
///
|
|
104
|
-
/// # Returns
|
|
105
|
-
/// * `Ok(Response)` - 204 No Content with CORS headers
|
|
106
|
-
/// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
|
|
107
|
-
pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
|
|
108
|
-
let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
|
|
109
|
-
|
|
110
|
-
if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
|
|
111
|
-
return Err(Box::new(
|
|
112
|
-
(
|
|
113
|
-
StatusCode::FORBIDDEN,
|
|
114
|
-
axum::Json(serde_json::json!({
|
|
115
|
-
"detail": format!("CORS request from origin '{}' not allowed", origin)
|
|
116
|
-
})),
|
|
117
|
-
)
|
|
118
|
-
.into_response(),
|
|
119
|
-
));
|
|
120
|
-
}
|
|
121
|
-
|
|
122
|
-
let requested_method = headers
|
|
123
|
-
.get("access-control-request-method")
|
|
124
|
-
.and_then(|v| v.to_str().ok())
|
|
125
|
-
.unwrap_or("");
|
|
126
|
-
|
|
127
|
-
if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
|
|
128
|
-
return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
let requested_headers_str = headers
|
|
132
|
-
.get("access-control-request-headers")
|
|
133
|
-
.and_then(|v| v.to_str().ok());
|
|
134
|
-
|
|
135
|
-
if let Some(req_headers) = requested_headers_str {
|
|
136
|
-
let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
|
|
137
|
-
|
|
138
|
-
if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
|
|
139
|
-
return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
|
|
140
|
-
}
|
|
141
|
-
}
|
|
142
|
-
|
|
143
|
-
let mut response = Response::builder().status(StatusCode::NO_CONTENT);
|
|
144
|
-
|
|
145
|
-
let headers_mut = match response.headers_mut() {
|
|
146
|
-
Some(headers) => headers,
|
|
147
|
-
None => {
|
|
148
|
-
return Err(Box::new(
|
|
149
|
-
(
|
|
150
|
-
StatusCode::INTERNAL_SERVER_ERROR,
|
|
151
|
-
axum::Json(serde_json::json!({
|
|
152
|
-
"detail": "Failed to construct response headers"
|
|
153
|
-
})),
|
|
154
|
-
)
|
|
155
|
-
.into_response(),
|
|
156
|
-
));
|
|
157
|
-
}
|
|
158
|
-
};
|
|
159
|
-
|
|
160
|
-
headers_mut.insert(
|
|
161
|
-
"access-control-allow-origin",
|
|
162
|
-
HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
163
|
-
);
|
|
164
|
-
|
|
165
|
-
let methods = cors_config.allowed_methods.join(", ");
|
|
166
|
-
headers_mut.insert(
|
|
167
|
-
"access-control-allow-methods",
|
|
168
|
-
HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
169
|
-
);
|
|
170
|
-
|
|
171
|
-
let allowed_headers = cors_config.allowed_headers.join(", ");
|
|
172
|
-
headers_mut.insert(
|
|
173
|
-
"access-control-allow-headers",
|
|
174
|
-
HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
175
|
-
);
|
|
176
|
-
|
|
177
|
-
if let Some(max_age) = cors_config.max_age
|
|
178
|
-
&& let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
|
|
179
|
-
{
|
|
180
|
-
headers_mut.insert("access-control-max-age", header_val);
|
|
181
|
-
}
|
|
182
|
-
|
|
183
|
-
if let Some(true) = cors_config.allow_credentials {
|
|
184
|
-
headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
|
|
185
|
-
}
|
|
186
|
-
|
|
187
|
-
match response.body(Body::empty()) {
|
|
188
|
-
Ok(resp) => Ok(resp),
|
|
189
|
-
Err(_) => Err(Box::new(
|
|
190
|
-
(
|
|
191
|
-
StatusCode::INTERNAL_SERVER_ERROR,
|
|
192
|
-
axum::Json(serde_json::json!({
|
|
193
|
-
"detail": "Failed to construct response body"
|
|
194
|
-
})),
|
|
195
|
-
)
|
|
196
|
-
.into_response(),
|
|
197
|
-
)),
|
|
198
|
-
}
|
|
199
|
-
}
|
|
200
|
-
|
|
201
|
-
/// Add CORS headers to a successful response
|
|
202
|
-
///
|
|
203
|
-
/// Adds appropriate CORS headers to the response based on the configuration.
|
|
204
|
-
/// This function should be called for successful (non-error) responses to
|
|
205
|
-
/// cross-origin requests.
|
|
206
|
-
///
|
|
207
|
-
/// # Headers Added
|
|
208
|
-
///
|
|
209
|
-
/// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
|
|
210
|
-
/// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
|
|
211
|
-
/// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
|
|
212
|
-
///
|
|
213
|
-
/// # Arguments
|
|
214
|
-
/// * `response` - Mutable reference to the response to modify
|
|
215
|
-
/// * `origin` - The origin from the request (e.g., `<https://example.com>`)
|
|
216
|
-
/// * `cors_config` - CORS configuration to apply
|
|
217
|
-
///
|
|
218
|
-
/// # Example
|
|
219
|
-
///
|
|
220
|
-
/// ```ignore
|
|
221
|
-
/// let mut response = Response::new(Body::empty());
|
|
222
|
-
/// add_cors_headers(&mut response, "https://example.com", &cors_config);
|
|
223
|
-
/// ```
|
|
224
|
-
pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
|
|
225
|
-
let headers = response.headers_mut();
|
|
226
|
-
|
|
227
|
-
if let Ok(origin_value) = HeaderValue::from_str(origin) {
|
|
228
|
-
headers.insert("access-control-allow-origin", origin_value);
|
|
229
|
-
}
|
|
230
|
-
|
|
231
|
-
if let Some(ref expose_headers) = cors_config.expose_headers {
|
|
232
|
-
let expose = expose_headers.join(", ");
|
|
233
|
-
if let Ok(expose_value) = HeaderValue::from_str(&expose) {
|
|
234
|
-
headers.insert("access-control-expose-headers", expose_value);
|
|
235
|
-
}
|
|
236
|
-
}
|
|
237
|
-
|
|
238
|
-
if let Some(true) = cors_config.allow_credentials {
|
|
239
|
-
headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
|
|
240
|
-
}
|
|
241
|
-
}
|
|
242
|
-
|
|
243
|
-
/// Validate a non-preflight CORS request
|
|
244
|
-
///
|
|
245
|
-
/// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
|
|
246
|
-
/// Returns an error response if validation fails.
|
|
247
|
-
///
|
|
248
|
-
/// # Validation
|
|
249
|
-
///
|
|
250
|
-
/// - If no Origin header is present, the request is allowed (not a CORS request)
|
|
251
|
-
/// - If Origin header is present, it must match the allowed origins
|
|
252
|
-
///
|
|
253
|
-
/// # Arguments
|
|
254
|
-
/// * `headers` - Request headers containing origin information
|
|
255
|
-
/// * `cors_config` - CORS configuration to validate against
|
|
256
|
-
///
|
|
257
|
-
/// # Returns
|
|
258
|
-
/// * `Ok(())` - Request is allowed
|
|
259
|
-
/// * `Err(Response)` - 403 Forbidden with error details
|
|
260
|
-
///
|
|
261
|
-
/// # Note
|
|
262
|
-
///
|
|
263
|
-
/// This function is for actual requests, not OPTIONS preflight requests.
|
|
264
|
-
/// Use `handle_preflight` for OPTIONS requests.
|
|
265
|
-
pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
|
|
266
|
-
let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
|
|
267
|
-
|
|
268
|
-
if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
|
|
269
|
-
return Err(Box::new(
|
|
270
|
-
(
|
|
271
|
-
StatusCode::FORBIDDEN,
|
|
272
|
-
axum::Json(serde_json::json!({
|
|
273
|
-
"detail": format!("CORS request from origin '{}' not allowed", origin)
|
|
274
|
-
})),
|
|
275
|
-
)
|
|
276
|
-
.into_response(),
|
|
277
|
-
));
|
|
278
|
-
}
|
|
279
|
-
Ok(())
|
|
280
|
-
}
|
|
281
|
-
|
|
282
|
-
#[cfg(test)]
|
|
283
|
-
mod tests {
|
|
284
|
-
use super::*;
|
|
285
|
-
|
|
286
|
-
fn make_cors_config() -> CorsConfig {
|
|
287
|
-
CorsConfig {
|
|
288
|
-
allowed_origins: vec!["https://example.com".to_string()],
|
|
289
|
-
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
|
290
|
-
allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
|
|
291
|
-
expose_headers: Some(vec!["x-custom-header".to_string()]),
|
|
292
|
-
max_age: Some(3600),
|
|
293
|
-
allow_credentials: Some(true),
|
|
294
|
-
}
|
|
295
|
-
}
|
|
296
|
-
|
|
297
|
-
#[test]
|
|
298
|
-
fn test_is_origin_allowed_exact_match() {
|
|
299
|
-
let allowed = vec!["https://example.com".to_string()];
|
|
300
|
-
assert!(is_origin_allowed("https://example.com", &allowed));
|
|
301
|
-
assert!(!is_origin_allowed("https://evil.com", &allowed));
|
|
302
|
-
}
|
|
303
|
-
|
|
304
|
-
#[test]
|
|
305
|
-
fn test_is_origin_allowed_wildcard() {
|
|
306
|
-
let allowed = vec!["*".to_string()];
|
|
307
|
-
assert!(is_origin_allowed("https://example.com", &allowed));
|
|
308
|
-
assert!(is_origin_allowed("https://any-domain.com", &allowed));
|
|
309
|
-
}
|
|
310
|
-
|
|
311
|
-
#[test]
|
|
312
|
-
fn test_is_origin_allowed_empty_origin() {
|
|
313
|
-
let allowed = vec!["*".to_string()];
|
|
314
|
-
assert!(!is_origin_allowed("", &allowed));
|
|
315
|
-
}
|
|
316
|
-
|
|
317
|
-
#[test]
|
|
318
|
-
fn test_is_method_allowed_case_insensitive() {
|
|
319
|
-
let allowed = vec!["GET".to_string(), "POST".to_string()];
|
|
320
|
-
assert!(is_method_allowed("GET", &allowed));
|
|
321
|
-
assert!(is_method_allowed("get", &allowed));
|
|
322
|
-
assert!(is_method_allowed("POST", &allowed));
|
|
323
|
-
assert!(is_method_allowed("post", &allowed));
|
|
324
|
-
assert!(!is_method_allowed("DELETE", &allowed));
|
|
325
|
-
}
|
|
326
|
-
|
|
327
|
-
#[test]
|
|
328
|
-
fn test_is_method_allowed_wildcard() {
|
|
329
|
-
let allowed = vec!["*".to_string()];
|
|
330
|
-
assert!(is_method_allowed("GET", &allowed));
|
|
331
|
-
assert!(is_method_allowed("DELETE", &allowed));
|
|
332
|
-
assert!(is_method_allowed("PATCH", &allowed));
|
|
333
|
-
}
|
|
334
|
-
|
|
335
|
-
#[test]
|
|
336
|
-
fn test_are_headers_allowed_case_insensitive() {
|
|
337
|
-
let allowed = vec!["Content-Type".to_string(), "Authorization".to_string()];
|
|
338
|
-
assert!(are_headers_allowed(&["content-type"], &allowed));
|
|
339
|
-
assert!(are_headers_allowed(&["AUTHORIZATION"], &allowed));
|
|
340
|
-
assert!(are_headers_allowed(&["content-type", "authorization"], &allowed));
|
|
341
|
-
assert!(!are_headers_allowed(&["x-custom"], &allowed));
|
|
342
|
-
}
|
|
343
|
-
|
|
344
|
-
#[test]
|
|
345
|
-
fn test_are_headers_allowed_wildcard() {
|
|
346
|
-
let allowed = vec!["*".to_string()];
|
|
347
|
-
assert!(are_headers_allowed(&["any-header"], &allowed));
|
|
348
|
-
assert!(are_headers_allowed(&["multiple", "headers"], &allowed));
|
|
349
|
-
}
|
|
350
|
-
|
|
351
|
-
#[test]
|
|
352
|
-
fn test_handle_preflight_success() {
|
|
353
|
-
let config = make_cors_config();
|
|
354
|
-
let mut headers = HeaderMap::new();
|
|
355
|
-
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
356
|
-
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
357
|
-
headers.insert(
|
|
358
|
-
"access-control-request-headers",
|
|
359
|
-
HeaderValue::from_static("content-type"),
|
|
360
|
-
);
|
|
361
|
-
|
|
362
|
-
let result = handle_preflight(&headers, &config);
|
|
363
|
-
assert!(result.is_ok());
|
|
364
|
-
|
|
365
|
-
let response = result.unwrap();
|
|
366
|
-
assert_eq!(response.status(), StatusCode::NO_CONTENT);
|
|
367
|
-
|
|
368
|
-
let resp_headers = response.headers();
|
|
369
|
-
assert_eq!(
|
|
370
|
-
resp_headers.get("access-control-allow-origin").unwrap(),
|
|
371
|
-
"https://example.com"
|
|
372
|
-
);
|
|
373
|
-
assert!(
|
|
374
|
-
resp_headers
|
|
375
|
-
.get("access-control-allow-methods")
|
|
376
|
-
.unwrap()
|
|
377
|
-
.to_str()
|
|
378
|
-
.unwrap()
|
|
379
|
-
.contains("POST")
|
|
380
|
-
);
|
|
381
|
-
assert_eq!(resp_headers.get("access-control-max-age").unwrap(), "3600");
|
|
382
|
-
assert_eq!(resp_headers.get("access-control-allow-credentials").unwrap(), "true");
|
|
383
|
-
}
|
|
384
|
-
|
|
385
|
-
#[test]
|
|
386
|
-
fn test_handle_preflight_origin_not_allowed() {
|
|
387
|
-
let config = make_cors_config();
|
|
388
|
-
let mut headers = HeaderMap::new();
|
|
389
|
-
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
390
|
-
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
391
|
-
|
|
392
|
-
let result = handle_preflight(&headers, &config);
|
|
393
|
-
assert!(result.is_err());
|
|
394
|
-
|
|
395
|
-
let response = *result.unwrap_err();
|
|
396
|
-
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
397
|
-
}
|
|
398
|
-
|
|
399
|
-
#[test]
|
|
400
|
-
fn test_handle_preflight_method_not_allowed() {
|
|
401
|
-
let config = make_cors_config();
|
|
402
|
-
let mut headers = HeaderMap::new();
|
|
403
|
-
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
404
|
-
headers.insert("access-control-request-method", HeaderValue::from_static("DELETE"));
|
|
405
|
-
|
|
406
|
-
let result = handle_preflight(&headers, &config);
|
|
407
|
-
assert!(result.is_err());
|
|
408
|
-
|
|
409
|
-
let response = *result.unwrap_err();
|
|
410
|
-
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
411
|
-
}
|
|
412
|
-
|
|
413
|
-
#[test]
|
|
414
|
-
fn test_handle_preflight_header_not_allowed() {
|
|
415
|
-
let config = make_cors_config();
|
|
416
|
-
let mut headers = HeaderMap::new();
|
|
417
|
-
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
418
|
-
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
419
|
-
headers.insert(
|
|
420
|
-
"access-control-request-headers",
|
|
421
|
-
HeaderValue::from_static("x-forbidden-header"),
|
|
422
|
-
);
|
|
423
|
-
|
|
424
|
-
let result = handle_preflight(&headers, &config);
|
|
425
|
-
assert!(result.is_err());
|
|
426
|
-
|
|
427
|
-
let response = *result.unwrap_err();
|
|
428
|
-
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
429
|
-
}
|
|
430
|
-
|
|
431
|
-
#[test]
|
|
432
|
-
fn test_handle_preflight_empty_origin() {
|
|
433
|
-
let config = make_cors_config();
|
|
434
|
-
let headers = HeaderMap::new();
|
|
435
|
-
|
|
436
|
-
let result = handle_preflight(&headers, &config);
|
|
437
|
-
assert!(result.is_err());
|
|
438
|
-
|
|
439
|
-
let response = *result.unwrap_err();
|
|
440
|
-
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
441
|
-
}
|
|
442
|
-
|
|
443
|
-
#[test]
|
|
444
|
-
fn test_add_cors_headers() {
|
|
445
|
-
let config = make_cors_config();
|
|
446
|
-
let mut response = Response::new(Body::empty());
|
|
447
|
-
|
|
448
|
-
add_cors_headers(&mut response, "https://example.com", &config);
|
|
449
|
-
|
|
450
|
-
let headers = response.headers();
|
|
451
|
-
assert_eq!(
|
|
452
|
-
headers.get("access-control-allow-origin").unwrap(),
|
|
453
|
-
"https://example.com"
|
|
454
|
-
);
|
|
455
|
-
assert_eq!(headers.get("access-control-expose-headers").unwrap(), "x-custom-header");
|
|
456
|
-
assert_eq!(headers.get("access-control-allow-credentials").unwrap(), "true");
|
|
457
|
-
}
|
|
458
|
-
|
|
459
|
-
#[test]
|
|
460
|
-
fn test_validate_cors_request_allowed() {
|
|
461
|
-
let config = make_cors_config();
|
|
462
|
-
let mut headers = HeaderMap::new();
|
|
463
|
-
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
464
|
-
|
|
465
|
-
let result = validate_cors_request(&headers, &config);
|
|
466
|
-
assert!(result.is_ok());
|
|
467
|
-
}
|
|
468
|
-
|
|
469
|
-
#[test]
|
|
470
|
-
fn test_validate_cors_request_not_allowed() {
|
|
471
|
-
let config = make_cors_config();
|
|
472
|
-
let mut headers = HeaderMap::new();
|
|
473
|
-
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
474
|
-
|
|
475
|
-
let result = validate_cors_request(&headers, &config);
|
|
476
|
-
assert!(result.is_err());
|
|
477
|
-
|
|
478
|
-
let response = *result.unwrap_err();
|
|
479
|
-
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
480
|
-
}
|
|
481
|
-
|
|
482
|
-
#[test]
|
|
483
|
-
fn test_validate_cors_request_no_origin() {
|
|
484
|
-
let config = make_cors_config();
|
|
485
|
-
let headers = HeaderMap::new();
|
|
486
|
-
|
|
487
|
-
let result = validate_cors_request(&headers, &config);
|
|
488
|
-
assert!(result.is_ok());
|
|
489
|
-
}
|
|
490
|
-
|
|
491
|
-
// SECURITY TESTS: CORS Attack Vectors
|
|
492
|
-
|
|
493
|
-
/// SECURITY TEST: Verify credentials=true with wildcard is caught
|
|
494
|
-
/// This is a critical vulnerability - RFC 6454 forbids this
|
|
495
|
-
#[test]
|
|
496
|
-
fn test_credentials_with_wildcard_config_is_security_issue() {
|
|
497
|
-
let config = CorsConfig {
|
|
498
|
-
allowed_origins: vec!["*".to_string()],
|
|
499
|
-
allowed_methods: vec!["GET".to_string()],
|
|
500
|
-
allowed_headers: vec![],
|
|
501
|
-
expose_headers: None,
|
|
502
|
-
max_age: None,
|
|
503
|
-
allow_credentials: Some(true), // SECURITY BUG: This should not be allowed with wildcard
|
|
504
|
-
};
|
|
505
|
-
|
|
506
|
-
let mut headers = HeaderMap::new();
|
|
507
|
-
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
508
|
-
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
509
|
-
|
|
510
|
-
let result = handle_preflight(&headers, &config);
|
|
511
|
-
|
|
512
|
-
// BUG: This should return 500 or reject the config, but instead succeeds
|
|
513
|
-
if let Ok(response) = result {
|
|
514
|
-
let resp_headers = response.headers();
|
|
515
|
-
let has_credentials = resp_headers
|
|
516
|
-
.get("access-control-allow-credentials")
|
|
517
|
-
.map(|v| v == "true")
|
|
518
|
-
.unwrap_or(false);
|
|
519
|
-
let origin_header = resp_headers.get("access-control-allow-origin");
|
|
520
|
-
|
|
521
|
-
if has_credentials && origin_header.is_some() {
|
|
522
|
-
let origin_val = origin_header.unwrap().to_str().unwrap_or("");
|
|
523
|
-
if origin_val == "*" {
|
|
524
|
-
panic!("SECURITY VULNERABILITY: credentials=true with origin=* allowed");
|
|
525
|
-
}
|
|
526
|
-
}
|
|
527
|
-
}
|
|
528
|
-
}
|
|
529
|
-
|
|
530
|
-
/// SECURITY TEST: Exact origin matching required
|
|
531
|
-
/// Subdomain like api.evil.example.com must NOT match example.com
|
|
532
|
-
#[test]
|
|
533
|
-
fn test_subdomain_bypass_blocked() {
|
|
534
|
-
let config = CorsConfig {
|
|
535
|
-
allowed_origins: vec!["https://example.com".to_string()],
|
|
536
|
-
allowed_methods: vec!["GET".to_string()],
|
|
537
|
-
allowed_headers: vec![],
|
|
538
|
-
expose_headers: None,
|
|
539
|
-
max_age: None,
|
|
540
|
-
allow_credentials: None,
|
|
541
|
-
};
|
|
542
|
-
|
|
543
|
-
assert!(!is_origin_allowed("https://api.example.com", &config.allowed_origins));
|
|
544
|
-
assert!(!is_origin_allowed("https://evil.example.com", &config.allowed_origins));
|
|
545
|
-
assert!(!is_origin_allowed(
|
|
546
|
-
"https://sub.sub.example.com",
|
|
547
|
-
&config.allowed_origins
|
|
548
|
-
));
|
|
549
|
-
|
|
550
|
-
assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
551
|
-
}
|
|
552
|
-
|
|
553
|
-
/// SECURITY TEST: Port exact matching required
/// An origin on any other port (3001, 8080, 443) must NOT match localhost:3000.
#[test]
fn test_port_bypass_blocked() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("http://localhost:3000")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    let other_ports = ["http://localhost:3001", "http://localhost:8080", "http://localhost:443"];
    for origin in other_ports {
        assert!(!is_origin_allowed(origin, allow_list), "port mismatch must be rejected: {origin}");
    }

    assert!(is_origin_allowed("http://localhost:3000", allow_list));
}
|
|
572
|
-
|
|
573
|
-
/// SECURITY TEST: Protocol exact matching required
/// The same host under a different scheme (http, ws, wss) must NOT match https://example.com.
#[test]
fn test_protocol_downgrade_attack_blocked() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://example.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    for other_scheme in ["http://example.com", "ws://example.com", "wss://example.com"] {
        assert!(!is_origin_allowed(other_scheme, allow_list), "scheme mismatch: {other_scheme}");
    }

    assert!(is_origin_allowed("https://example.com", allow_list));
}
|
|
592
|
-
|
|
593
|
-
/// SECURITY TEST: Case sensitivity in origin matching
/// Origins are compared byte-for-byte, so differently-cased spellings do not match.
#[test]
fn test_case_sensitive_origin_matching() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://Example.Com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    for recased in ["https://example.com", "https://EXAMPLE.COM"] {
        assert!(!is_origin_allowed(recased, allow_list), "{recased} must not match");
    }

    assert!(is_origin_allowed("https://Example.Com", allow_list));
}
|
|
611
|
-
|
|
612
|
-
/// SECURITY TEST: Trailing slash normalization
/// No normalization is applied: `https://example.com/` is a different string from
/// `https://example.com` and therefore does not match it.
#[test]
fn test_trailing_slash_origin_not_normalized() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://example.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    let with_slash = "https://example.com/";
    let without_slash = "https://example.com";

    assert!(!is_origin_allowed(with_slash, allow_list));
    assert!(is_origin_allowed(without_slash, allow_list));
}
|
|
629
|
-
|
|
630
|
-
/// SECURITY TEST: NULL origin and wildcard behavior
/// Browsers send the special origin "null" for file:// pages and sandboxed iframes.
/// The matcher treats "null" as an ordinary origin string, so a wildcard allow-list
/// accepts it (not ideal; this test documents the current behavior).
#[test]
fn test_null_origin_with_wildcard() {
    let single_origin_config = |origin: &str| CorsConfig {
        allowed_origins: vec![origin.to_string()],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    // SECURITY NOTE: "null" origin is allowed by wildcard in current implementation
    let wildcard = single_origin_config("*");
    assert!(is_origin_allowed("null", &wildcard.allowed_origins));

    let explicit_null = single_origin_config("null");
    assert!(is_origin_allowed("null", &explicit_null.allowed_origins));
}
|
|
658
|
-
|
|
659
|
-
/// SECURITY TEST: Empty origin is always rejected
/// This holds for both a wildcard allow-list and an explicit one.
#[test]
fn test_empty_origin_always_rejected() {
    let single_origin_config = |origin: &str| CorsConfig {
        allowed_origins: vec![origin.to_string()],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    for configured in ["*", "https://example.com"] {
        let config = single_origin_config(configured);
        assert!(
            !is_origin_allowed("", &config.allowed_origins),
            "empty origin accepted by allow-list [{configured}]"
        );
    }
}
|
|
682
|
-
|
|
683
|
-
/// SECURITY TEST: Preflight with invalid origin should reject with 403 Forbidden
#[test]
fn test_preflight_rejects_invalid_origin() {
    let mut request_headers = HeaderMap::new();
    request_headers.insert("origin", HeaderValue::from_static("https://untrusted.com"));
    request_headers.insert("access-control-request-method", HeaderValue::from_static("POST"));

    let rejection = handle_preflight(&request_headers, &make_cors_config())
        .expect_err("an origin outside the allow-list must be rejected");

    assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
}
|
|
697
|
-
|
|
698
|
-
/// SECURITY TEST: Multiple origins - each must be an exact match
/// Look-alike, suffixed, or shortened hosts must not slip through.
#[test]
fn test_multiple_origins_exact_matching() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted1.com"), String::from("https://trusted2.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    for trusted in ["https://trusted1.com", "https://trusted2.com"] {
        assert!(is_origin_allowed(trusted, allow_list), "{trusted} should be allowed");
    }

    for lookalike in [
        "https://trusted1.com.evil.com",
        "https://trusted3.com",
        "https://trusted.com",
    ] {
        assert!(!is_origin_allowed(lookalike, allow_list), "{lookalike} must be rejected");
    }
}
|
|
720
|
-
|
|
721
|
-
/// SECURITY TEST: A wildcard allow-list accepts any non-empty origin
/// The empty origin is still rejected.
#[test]
fn test_wildcard_allows_all_but_check_credentials() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("*")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    for origin in ["https://example.com", "https://evil.com", "http://localhost:3000"] {
        assert!(is_origin_allowed(origin, allow_list), "wildcard should admit {origin}");
    }

    assert!(!is_origin_allowed("", allow_list));
}
|
|
739
|
-
|
|
740
|
-
/// SECURITY TEST: Preflight response headers must match config exactly
/// Checks the echoed origin and the advertised methods. No credentials header may
/// appear when `allow_credentials` is `Some(false)`.
#[test]
fn test_preflight_response_has_correct_allowed_origins() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted.com")],
        allowed_methods: vec![String::from("GET"), String::from("POST")],
        allowed_headers: vec![String::from("content-type")],
        expose_headers: None,
        max_age: Some(3600),
        allow_credentials: Some(false),
    };

    let mut request_headers = HeaderMap::new();
    request_headers.insert("origin", HeaderValue::from_static("https://trusted.com"));
    request_headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
    request_headers.insert("access-control-request-headers", HeaderValue::from_static("content-type"));

    let response = handle_preflight(&request_headers, &config)
        .expect("trusted origin with allowed method and headers must pass");
    let resp_headers = response.headers();

    assert_eq!(resp_headers.get("access-control-allow-origin").unwrap(), "https://trusted.com");

    let advertised_methods = resp_headers
        .get("access-control-allow-methods")
        .unwrap()
        .to_str()
        .unwrap();
    for method in ["GET", "POST"] {
        assert!(advertised_methods.contains(method), "{method} missing from allow-methods");
    }

    assert!(resp_headers.get("access-control-allow-credentials").is_none());
}
|
|
790
|
-
|
|
791
|
-
/// SECURITY TEST: Origin not in allowed list must be rejected in preflight
/// Only the exact configured origin passes; look-alikes, other schemes and the empty
/// origin are rejected.
#[test]
fn test_preflight_all_origins_require_validation() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let cases = [
        ("https://trusted.com", true),
        ("https://evil.com", false),
        ("https://trusted.com.evil", false),
        ("http://trusted.com", false),
        ("", false),
    ];

    for (origin, expect_allowed) in cases {
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            "origin",
            HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("")),
        );
        request_headers.insert("access-control-request-method", HeaderValue::from_static("GET"));

        let accepted = handle_preflight(&request_headers, &config).is_ok();

        if expect_allowed {
            assert!(accepted, "Valid origin {} should be allowed", origin);
        } else {
            assert!(!accepted, "Invalid origin {} should be rejected", origin);
        }
    }
}
|
|
828
|
-
|
|
829
|
-
/// SECURITY TEST: Requested headers must be in allowed list
/// A single unlisted header in `Access-Control-Request-Headers` rejects the whole preflight.
#[test]
fn test_preflight_validates_all_requested_headers() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted.com")],
        allowed_methods: vec![String::from("POST")],
        allowed_headers: vec![String::from("content-type"), String::from("authorization")],
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let cases = [
        ("content-type", true),
        ("authorization", true),
        ("content-type, authorization", true),
        ("x-custom-header", false),
        ("content-type, x-custom", false),
    ];

    for (headers_str, should_pass) in cases {
        let mut request_headers = HeaderMap::new();
        request_headers.insert("origin", HeaderValue::from_static("https://trusted.com"));
        request_headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
        request_headers.insert("access-control-request-headers", HeaderValue::from_str(headers_str).unwrap());

        let accepted = handle_preflight(&request_headers, &config).is_ok();

        if should_pass {
            assert!(accepted, "Preflight with valid headers '{}' should pass", headers_str);
        } else {
            assert!(!accepted, "Preflight with invalid headers '{}' should fail", headers_str);
        }
    }
}
|
|
875
|
-
|
|
876
|
-
/// SECURITY TEST: add_cors_headers sets the expected headers for a trusted origin
/// Expected headers: the echoed origin, the configured expose-headers, and the credentials flag.
#[test]
fn test_add_cors_headers_respects_origin() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: Some(vec![String::from("x-custom")]),
        max_age: None,
        allow_credentials: Some(true),
    };

    let mut response = Response::new(Body::empty());
    add_cors_headers(&mut response, "https://trusted.com", &config);

    let expected = [
        ("access-control-allow-origin", "https://trusted.com"),
        ("access-control-expose-headers", "x-custom"),
        ("access-control-allow-credentials", "true"),
    ];
    for (name, value) in expected {
        assert_eq!(response.headers().get(name).unwrap(), value);
    }
}
|
|
900
|
-
|
|
901
|
-
/// SECURITY TEST: validate_cors_request respects allowed origins
/// A request without an Origin header is not a cross-origin request and is let through.
#[test]
fn test_validate_cors_request_origin_must_match() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://trusted.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    let headers_with_origin = |origin: &'static str| {
        let mut map = HeaderMap::new();
        map.insert("origin", HeaderValue::from_static(origin));
        map
    };

    assert!(validate_cors_request(&headers_with_origin("https://trusted.com"), &config).is_ok());
    assert!(validate_cors_request(&headers_with_origin("https://evil.com"), &config).is_err());
    assert!(validate_cors_request(&HeaderMap::new(), &config).is_ok());
}
|
|
924
|
-
|
|
925
|
-
/// SECURITY TEST: Preflight without `Access-Control-Request-Method`
///
/// NOTE: despite the test name, `handle_preflight` does NOT currently require this
/// header. A preflight that carries only an allowed `Origin` is accepted, and this test
/// pins that lenient behavior. If the method header is ever made mandatory, flip the
/// assertion to `is_err()`.
#[test]
fn test_preflight_requires_access_control_request_method() {
    let config = make_cors_config();
    let mut headers = HeaderMap::new();
    headers.insert("origin", HeaderValue::from_static("https://example.com"));

    let result = handle_preflight(&headers, &config);
    // Accepted today: a missing request-method header is not treated as an error.
    assert!(result.is_ok());
}
|
|
935
|
-
|
|
936
|
-
/// SECURITY TEST: Case-insensitive method matching
/// Any ASCII casing of a configured method passes the preflight check.
#[test]
fn test_preflight_method_case_insensitive() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://example.com")],
        allowed_methods: vec![String::from("GET"), String::from("POST")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };

    for method in ["GET", "get", "Get", "POST", "post"] {
        let mut request_headers = HeaderMap::new();
        request_headers.insert("origin", HeaderValue::from_static("https://example.com"));
        request_headers.insert("access-control-request-method", HeaderValue::from_str(method).unwrap());

        let outcome = handle_preflight(&request_headers, &config);
        assert!(outcome.is_ok(), "Method '{}' should be allowed (case-insensitive)", method);
    }
}
|
|
963
|
-
|
|
964
|
-
/// SECURITY TEST: Ensure preflight max-age is set correctly
/// The configured `max_age` is emitted verbatim as `Access-Control-Max-Age`.
#[test]
fn test_preflight_max_age_header() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("https://example.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: Some(7200),
        allow_credentials: None,
    };

    let mut request_headers = HeaderMap::new();
    request_headers.insert("origin", HeaderValue::from_static("https://example.com"));
    request_headers.insert("access-control-request-method", HeaderValue::from_static("GET"));

    let response = handle_preflight(&request_headers, &config).expect("valid preflight should succeed");
    let max_age = response.headers().get("access-control-max-age").unwrap();
    assert_eq!(max_age, "7200");
}
|
|
986
|
-
|
|
987
|
-
/// SECURITY TEST: Wildcard partial patterns should not work
/// `*.example.com` is not a glob; it only matches the literal string `*.example.com` (good!).
#[test]
fn test_wildcard_patterns_not_supported() {
    let config = CorsConfig {
        allowed_origins: vec![String::from("*.example.com")],
        allowed_methods: vec![String::from("GET")],
        allowed_headers: Vec::new(),
        expose_headers: None,
        max_age: None,
        allow_credentials: None,
    };
    let allow_list = &config.allowed_origins;

    for origin in ["https://api.example.com", "https://example.com"] {
        assert!(!is_origin_allowed(origin, allow_list), "pattern must not expand for {origin}");
    }

    assert!(is_origin_allowed("*.example.com", allow_list));
}
|
|
1005
|
-
}
|
|
1
|
+
//! CORS (Cross-Origin Resource Sharing) handling
|
|
2
|
+
//!
|
|
3
|
+
//! Handles CORS preflight requests and adds CORS headers to responses
|
|
4
|
+
|
|
5
|
+
use crate::CorsConfig;
|
|
6
|
+
use axum::body::Body;
|
|
7
|
+
use axum::http::{HeaderMap, HeaderValue, Response, StatusCode};
|
|
8
|
+
use axum::response::IntoResponse;
|
|
9
|
+
|
|
10
|
+
/// Decide whether a request origin is permitted by the CORS allow-list.
///
/// An entry of `"*"` admits any origin; every other entry must equal `origin`
/// exactly (case-sensitive, no normalization). An empty `origin` is never allowed,
/// even with a wildcard entry.
///
/// # Arguments
/// * `origin` - The origin string from the HTTP request (e.g., "https://example.com")
/// * `allowed_origins` - List of allowed origins configured for CORS
///
/// # Returns
/// `true` if the origin is allowed, `false` otherwise
fn is_origin_allowed(origin: &str, allowed_origins: &[String]) -> bool {
    !origin.is_empty()
        && allowed_origins
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "*" || candidate == origin)
}
|
|
30
|
+
|
|
31
|
+
/// Decide whether an HTTP method is permitted by the CORS configuration.
///
/// An entry of `"*"` admits any method. Other entries are compared with ASCII
/// case-insensitivity, so "get" matches "GET".
///
/// # Arguments
/// * `method` - The HTTP method requested (e.g., "GET", "POST")
/// * `allowed_methods` - List of allowed HTTP methods configured for CORS
///
/// # Returns
/// `true` if the method is allowed, `false` otherwise
fn is_method_allowed(method: &str, allowed_methods: &[String]) -> bool {
    allowed_methods.iter().any(|entry| {
        let entry = entry.as_str();
        entry == "*" || method.eq_ignore_ascii_case(entry)
    })
}
|
|
47
|
+
|
|
48
|
+
/// Decide whether every requested header name is permitted by the CORS configuration.
///
/// A `"*"` entry in `allowed` admits all requested headers. Otherwise each requested
/// name must match some allowed entry, ignoring ASCII case. An empty `requested`
/// slice is trivially allowed.
///
/// # Arguments
/// * `requested` - Header names requested by the client
/// * `allowed` - List of allowed header names configured for CORS
///
/// # Returns
/// `true` if all requested headers are allowed, `false` if any header is not allowed
fn are_headers_allowed(requested: &[&str], allowed: &[String]) -> bool {
    let wildcard = allowed.iter().any(|entry| entry == "*");

    wildcard
        || requested
            .iter()
            .all(|name| allowed.iter().any(|entry| name.eq_ignore_ascii_case(entry)))
}
|
|
70
|
+
|
|
71
|
+
/// Handle CORS preflight (OPTIONS) request
|
|
72
|
+
///
|
|
73
|
+
/// Validates the request against the CORS configuration and returns appropriate
|
|
74
|
+
/// response or error. This function processes OPTIONS requests as defined in the
|
|
75
|
+
/// CORS specification (RFC 7231).
|
|
76
|
+
///
|
|
77
|
+
/// # Validation
|
|
78
|
+
///
|
|
79
|
+
/// Checks the following conditions:
|
|
80
|
+
/// 1. **Origin Header:** Must be present and match configured allowed origins
|
|
81
|
+
/// 2. **Access-Control-Request-Method:** Must match configured allowed methods
|
|
82
|
+
/// 3. **Access-Control-Request-Headers:** All requested headers must match configured allowed headers
|
|
83
|
+
///
|
|
84
|
+
/// # Success Response
|
|
85
|
+
///
|
|
86
|
+
/// Returns HTTP 204 (No Content) with the following response headers:
|
|
87
|
+
/// - `Access-Control-Allow-Origin` - The origin that is allowed
|
|
88
|
+
/// - `Access-Control-Allow-Methods` - Comma-separated list of allowed methods
|
|
89
|
+
/// - `Access-Control-Allow-Headers` - Comma-separated list of allowed headers
|
|
90
|
+
/// - `Access-Control-Max-Age` - Caching duration in seconds (if configured)
|
|
91
|
+
/// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
|
|
92
|
+
///
|
|
93
|
+
/// # Error Response
|
|
94
|
+
///
|
|
95
|
+
/// Returns HTTP 403 (Forbidden) if validation fails for:
|
|
96
|
+
/// - Origin not in allowed list
|
|
97
|
+
/// - Requested method not allowed
|
|
98
|
+
/// - Requested headers not allowed
|
|
99
|
+
///
|
|
100
|
+
/// # Arguments
|
|
101
|
+
/// * `headers` - Request headers containing CORS preflight information
|
|
102
|
+
/// * `cors_config` - CORS configuration to validate against
|
|
103
|
+
///
|
|
104
|
+
/// # Returns
|
|
105
|
+
/// * `Ok(Response)` - 204 No Content with CORS headers
|
|
106
|
+
/// * `Err(Response)` - 403 Forbidden or 500 Internal Server Error
|
|
107
|
+
pub fn handle_preflight(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<Response<Body>, Box<Response<Body>>> {
|
|
108
|
+
let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
|
|
109
|
+
|
|
110
|
+
if origin.is_empty() || !is_origin_allowed(origin, &cors_config.allowed_origins) {
|
|
111
|
+
return Err(Box::new(
|
|
112
|
+
(
|
|
113
|
+
StatusCode::FORBIDDEN,
|
|
114
|
+
axum::Json(serde_json::json!({
|
|
115
|
+
"detail": format!("CORS request from origin '{}' not allowed", origin)
|
|
116
|
+
})),
|
|
117
|
+
)
|
|
118
|
+
.into_response(),
|
|
119
|
+
));
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
let requested_method = headers
|
|
123
|
+
.get("access-control-request-method")
|
|
124
|
+
.and_then(|v| v.to_str().ok())
|
|
125
|
+
.unwrap_or("");
|
|
126
|
+
|
|
127
|
+
if !requested_method.is_empty() && !is_method_allowed(requested_method, &cors_config.allowed_methods) {
|
|
128
|
+
return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
let requested_headers_str = headers
|
|
132
|
+
.get("access-control-request-headers")
|
|
133
|
+
.and_then(|v| v.to_str().ok());
|
|
134
|
+
|
|
135
|
+
if let Some(req_headers) = requested_headers_str {
|
|
136
|
+
let requested_headers: Vec<&str> = req_headers.split(',').map(|h| h.trim()).collect();
|
|
137
|
+
|
|
138
|
+
if !are_headers_allowed(&requested_headers, &cors_config.allowed_headers) {
|
|
139
|
+
return Err(Box::new((StatusCode::FORBIDDEN).into_response()));
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
let mut response = Response::builder().status(StatusCode::NO_CONTENT);
|
|
144
|
+
|
|
145
|
+
let headers_mut = match response.headers_mut() {
|
|
146
|
+
Some(headers) => headers,
|
|
147
|
+
None => {
|
|
148
|
+
return Err(Box::new(
|
|
149
|
+
(
|
|
150
|
+
StatusCode::INTERNAL_SERVER_ERROR,
|
|
151
|
+
axum::Json(serde_json::json!({
|
|
152
|
+
"detail": "Failed to construct response headers"
|
|
153
|
+
})),
|
|
154
|
+
)
|
|
155
|
+
.into_response(),
|
|
156
|
+
));
|
|
157
|
+
}
|
|
158
|
+
};
|
|
159
|
+
|
|
160
|
+
headers_mut.insert(
|
|
161
|
+
"access-control-allow-origin",
|
|
162
|
+
HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
163
|
+
);
|
|
164
|
+
|
|
165
|
+
let methods = cors_config.allowed_methods.join(", ");
|
|
166
|
+
headers_mut.insert(
|
|
167
|
+
"access-control-allow-methods",
|
|
168
|
+
HeaderValue::from_str(&methods).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
169
|
+
);
|
|
170
|
+
|
|
171
|
+
let allowed_headers = cors_config.allowed_headers.join(", ");
|
|
172
|
+
headers_mut.insert(
|
|
173
|
+
"access-control-allow-headers",
|
|
174
|
+
HeaderValue::from_str(&allowed_headers).unwrap_or_else(|_| HeaderValue::from_static("*")),
|
|
175
|
+
);
|
|
176
|
+
|
|
177
|
+
if let Some(max_age) = cors_config.max_age
|
|
178
|
+
&& let Ok(header_val) = HeaderValue::from_str(&max_age.to_string())
|
|
179
|
+
{
|
|
180
|
+
headers_mut.insert("access-control-max-age", header_val);
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
if let Some(true) = cors_config.allow_credentials {
|
|
184
|
+
headers_mut.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
match response.body(Body::empty()) {
|
|
188
|
+
Ok(resp) => Ok(resp),
|
|
189
|
+
Err(_) => Err(Box::new(
|
|
190
|
+
(
|
|
191
|
+
StatusCode::INTERNAL_SERVER_ERROR,
|
|
192
|
+
axum::Json(serde_json::json!({
|
|
193
|
+
"detail": "Failed to construct response body"
|
|
194
|
+
})),
|
|
195
|
+
)
|
|
196
|
+
.into_response(),
|
|
197
|
+
)),
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/// Add CORS headers to a successful response
|
|
202
|
+
///
|
|
203
|
+
/// Adds appropriate CORS headers to the response based on the configuration.
|
|
204
|
+
/// This function should be called for successful (non-error) responses to
|
|
205
|
+
/// cross-origin requests.
|
|
206
|
+
///
|
|
207
|
+
/// # Headers Added
|
|
208
|
+
///
|
|
209
|
+
/// - `Access-Control-Allow-Origin` - The origin that is allowed (if valid)
|
|
210
|
+
/// - `Access-Control-Expose-Headers` - Headers that are safe to expose to the client
|
|
211
|
+
/// - `Access-Control-Allow-Credentials` - "true" if credentials are allowed
|
|
212
|
+
///
|
|
213
|
+
/// # Arguments
|
|
214
|
+
/// * `response` - Mutable reference to the response to modify
|
|
215
|
+
/// * `origin` - The origin from the request (e.g., `<https://example.com>`)
|
|
216
|
+
/// * `cors_config` - CORS configuration to apply
|
|
217
|
+
///
|
|
218
|
+
/// # Example
|
|
219
|
+
///
|
|
220
|
+
/// ```ignore
|
|
221
|
+
/// let mut response = Response::new(Body::empty());
|
|
222
|
+
/// add_cors_headers(&mut response, "https://example.com", &cors_config);
|
|
223
|
+
/// ```
|
|
224
|
+
pub fn add_cors_headers(response: &mut Response<Body>, origin: &str, cors_config: &CorsConfig) {
|
|
225
|
+
let headers = response.headers_mut();
|
|
226
|
+
|
|
227
|
+
if let Ok(origin_value) = HeaderValue::from_str(origin) {
|
|
228
|
+
headers.insert("access-control-allow-origin", origin_value);
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
if let Some(ref expose_headers) = cors_config.expose_headers {
|
|
232
|
+
let expose = expose_headers.join(", ");
|
|
233
|
+
if let Ok(expose_value) = HeaderValue::from_str(&expose) {
|
|
234
|
+
headers.insert("access-control-expose-headers", expose_value);
|
|
235
|
+
}
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
if let Some(true) = cors_config.allow_credentials {
|
|
239
|
+
headers.insert("access-control-allow-credentials", HeaderValue::from_static("true"));
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
/// Validate a non-preflight CORS request
|
|
244
|
+
///
|
|
245
|
+
/// Checks if the Origin header is present and allowed for non-preflight (actual) requests.
|
|
246
|
+
/// Returns an error response if validation fails.
|
|
247
|
+
///
|
|
248
|
+
/// # Validation
|
|
249
|
+
///
|
|
250
|
+
/// - If no Origin header is present, the request is allowed (not a CORS request)
|
|
251
|
+
/// - If Origin header is present, it must match the allowed origins
|
|
252
|
+
///
|
|
253
|
+
/// # Arguments
|
|
254
|
+
/// * `headers` - Request headers containing origin information
|
|
255
|
+
/// * `cors_config` - CORS configuration to validate against
|
|
256
|
+
///
|
|
257
|
+
/// # Returns
|
|
258
|
+
/// * `Ok(())` - Request is allowed
|
|
259
|
+
/// * `Err(Response)` - 403 Forbidden with error details
|
|
260
|
+
///
|
|
261
|
+
/// # Note
|
|
262
|
+
///
|
|
263
|
+
/// This function is for actual requests, not OPTIONS preflight requests.
|
|
264
|
+
/// Use `handle_preflight` for OPTIONS requests.
|
|
265
|
+
pub fn validate_cors_request(headers: &HeaderMap, cors_config: &CorsConfig) -> Result<(), Box<Response<Body>>> {
|
|
266
|
+
let origin = headers.get("origin").and_then(|v| v.to_str().ok()).unwrap_or("");
|
|
267
|
+
|
|
268
|
+
if !origin.is_empty() && !is_origin_allowed(origin, &cors_config.allowed_origins) {
|
|
269
|
+
return Err(Box::new(
|
|
270
|
+
(
|
|
271
|
+
StatusCode::FORBIDDEN,
|
|
272
|
+
axum::Json(serde_json::json!({
|
|
273
|
+
"detail": format!("CORS request from origin '{}' not allowed", origin)
|
|
274
|
+
})),
|
|
275
|
+
)
|
|
276
|
+
.into_response(),
|
|
277
|
+
));
|
|
278
|
+
}
|
|
279
|
+
Ok(())
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
#[cfg(test)]
|
|
283
|
+
mod tests {
|
|
284
|
+
use super::*;
|
|
285
|
+
|
|
286
|
+
fn make_cors_config() -> CorsConfig {
|
|
287
|
+
CorsConfig {
|
|
288
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
289
|
+
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
|
290
|
+
allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
|
|
291
|
+
expose_headers: Some(vec!["x-custom-header".to_string()]),
|
|
292
|
+
max_age: Some(3600),
|
|
293
|
+
allow_credentials: Some(true),
|
|
294
|
+
}
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
#[test]
fn test_is_origin_allowed_exact_match() {
    let allow_list = [String::from("https://example.com")];

    assert!(is_origin_allowed("https://example.com", &allow_list));
    assert!(!is_origin_allowed("https://evil.com", &allow_list));
}
|
|
303
|
+
|
|
304
|
+
#[test]
|
|
305
|
+
fn test_is_origin_allowed_wildcard() {
|
|
306
|
+
let allowed = vec!["*".to_string()];
|
|
307
|
+
assert!(is_origin_allowed("https://example.com", &allowed));
|
|
308
|
+
assert!(is_origin_allowed("https://any-domain.com", &allowed));
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
#[test]
|
|
312
|
+
fn test_is_origin_allowed_empty_origin() {
|
|
313
|
+
let allowed = vec!["*".to_string()];
|
|
314
|
+
assert!(!is_origin_allowed("", &allowed));
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
#[test]
|
|
318
|
+
fn test_is_method_allowed_case_insensitive() {
|
|
319
|
+
let allowed = vec!["GET".to_string(), "POST".to_string()];
|
|
320
|
+
assert!(is_method_allowed("GET", &allowed));
|
|
321
|
+
assert!(is_method_allowed("get", &allowed));
|
|
322
|
+
assert!(is_method_allowed("POST", &allowed));
|
|
323
|
+
assert!(is_method_allowed("post", &allowed));
|
|
324
|
+
assert!(!is_method_allowed("DELETE", &allowed));
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
#[test]
|
|
328
|
+
fn test_is_method_allowed_wildcard() {
|
|
329
|
+
let allowed = vec!["*".to_string()];
|
|
330
|
+
assert!(is_method_allowed("GET", &allowed));
|
|
331
|
+
assert!(is_method_allowed("DELETE", &allowed));
|
|
332
|
+
assert!(is_method_allowed("PATCH", &allowed));
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
#[test]
|
|
336
|
+
fn test_are_headers_allowed_case_insensitive() {
|
|
337
|
+
let allowed = vec!["Content-Type".to_string(), "Authorization".to_string()];
|
|
338
|
+
assert!(are_headers_allowed(&["content-type"], &allowed));
|
|
339
|
+
assert!(are_headers_allowed(&["AUTHORIZATION"], &allowed));
|
|
340
|
+
assert!(are_headers_allowed(&["content-type", "authorization"], &allowed));
|
|
341
|
+
assert!(!are_headers_allowed(&["x-custom"], &allowed));
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
#[test]
|
|
345
|
+
fn test_are_headers_allowed_wildcard() {
|
|
346
|
+
let allowed = vec!["*".to_string()];
|
|
347
|
+
assert!(are_headers_allowed(&["any-header"], &allowed));
|
|
348
|
+
assert!(are_headers_allowed(&["multiple", "headers"], &allowed));
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
#[test]
|
|
352
|
+
fn test_handle_preflight_success() {
|
|
353
|
+
let config = make_cors_config();
|
|
354
|
+
let mut headers = HeaderMap::new();
|
|
355
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
356
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
357
|
+
headers.insert(
|
|
358
|
+
"access-control-request-headers",
|
|
359
|
+
HeaderValue::from_static("content-type"),
|
|
360
|
+
);
|
|
361
|
+
|
|
362
|
+
let result = handle_preflight(&headers, &config);
|
|
363
|
+
assert!(result.is_ok());
|
|
364
|
+
|
|
365
|
+
let response = result.unwrap();
|
|
366
|
+
assert_eq!(response.status(), StatusCode::NO_CONTENT);
|
|
367
|
+
|
|
368
|
+
let resp_headers = response.headers();
|
|
369
|
+
assert_eq!(
|
|
370
|
+
resp_headers.get("access-control-allow-origin").unwrap(),
|
|
371
|
+
"https://example.com"
|
|
372
|
+
);
|
|
373
|
+
assert!(
|
|
374
|
+
resp_headers
|
|
375
|
+
.get("access-control-allow-methods")
|
|
376
|
+
.unwrap()
|
|
377
|
+
.to_str()
|
|
378
|
+
.unwrap()
|
|
379
|
+
.contains("POST")
|
|
380
|
+
);
|
|
381
|
+
assert_eq!(resp_headers.get("access-control-max-age").unwrap(), "3600");
|
|
382
|
+
assert_eq!(resp_headers.get("access-control-allow-credentials").unwrap(), "true");
|
|
383
|
+
}
|
|
384
|
+
|
|
385
|
+
#[test]
|
|
386
|
+
fn test_handle_preflight_origin_not_allowed() {
|
|
387
|
+
let config = make_cors_config();
|
|
388
|
+
let mut headers = HeaderMap::new();
|
|
389
|
+
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
390
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
391
|
+
|
|
392
|
+
let result = handle_preflight(&headers, &config);
|
|
393
|
+
assert!(result.is_err());
|
|
394
|
+
|
|
395
|
+
let response = *result.unwrap_err();
|
|
396
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
#[test]
|
|
400
|
+
fn test_handle_preflight_method_not_allowed() {
|
|
401
|
+
let config = make_cors_config();
|
|
402
|
+
let mut headers = HeaderMap::new();
|
|
403
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
404
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("DELETE"));
|
|
405
|
+
|
|
406
|
+
let result = handle_preflight(&headers, &config);
|
|
407
|
+
assert!(result.is_err());
|
|
408
|
+
|
|
409
|
+
let response = *result.unwrap_err();
|
|
410
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
#[test]
|
|
414
|
+
fn test_handle_preflight_header_not_allowed() {
|
|
415
|
+
let config = make_cors_config();
|
|
416
|
+
let mut headers = HeaderMap::new();
|
|
417
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
418
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
419
|
+
headers.insert(
|
|
420
|
+
"access-control-request-headers",
|
|
421
|
+
HeaderValue::from_static("x-forbidden-header"),
|
|
422
|
+
);
|
|
423
|
+
|
|
424
|
+
let result = handle_preflight(&headers, &config);
|
|
425
|
+
assert!(result.is_err());
|
|
426
|
+
|
|
427
|
+
let response = *result.unwrap_err();
|
|
428
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
#[test]
|
|
432
|
+
fn test_handle_preflight_empty_origin() {
|
|
433
|
+
let config = make_cors_config();
|
|
434
|
+
let headers = HeaderMap::new();
|
|
435
|
+
|
|
436
|
+
let result = handle_preflight(&headers, &config);
|
|
437
|
+
assert!(result.is_err());
|
|
438
|
+
|
|
439
|
+
let response = *result.unwrap_err();
|
|
440
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
#[test]
|
|
444
|
+
fn test_add_cors_headers() {
|
|
445
|
+
let config = make_cors_config();
|
|
446
|
+
let mut response = Response::new(Body::empty());
|
|
447
|
+
|
|
448
|
+
add_cors_headers(&mut response, "https://example.com", &config);
|
|
449
|
+
|
|
450
|
+
let headers = response.headers();
|
|
451
|
+
assert_eq!(
|
|
452
|
+
headers.get("access-control-allow-origin").unwrap(),
|
|
453
|
+
"https://example.com"
|
|
454
|
+
);
|
|
455
|
+
assert_eq!(headers.get("access-control-expose-headers").unwrap(), "x-custom-header");
|
|
456
|
+
assert_eq!(headers.get("access-control-allow-credentials").unwrap(), "true");
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
#[test]
|
|
460
|
+
fn test_validate_cors_request_allowed() {
|
|
461
|
+
let config = make_cors_config();
|
|
462
|
+
let mut headers = HeaderMap::new();
|
|
463
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
464
|
+
|
|
465
|
+
let result = validate_cors_request(&headers, &config);
|
|
466
|
+
assert!(result.is_ok());
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
#[test]
|
|
470
|
+
fn test_validate_cors_request_not_allowed() {
|
|
471
|
+
let config = make_cors_config();
|
|
472
|
+
let mut headers = HeaderMap::new();
|
|
473
|
+
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
474
|
+
|
|
475
|
+
let result = validate_cors_request(&headers, &config);
|
|
476
|
+
assert!(result.is_err());
|
|
477
|
+
|
|
478
|
+
let response = *result.unwrap_err();
|
|
479
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
#[test]
|
|
483
|
+
fn test_validate_cors_request_no_origin() {
|
|
484
|
+
let config = make_cors_config();
|
|
485
|
+
let headers = HeaderMap::new();
|
|
486
|
+
|
|
487
|
+
let result = validate_cors_request(&headers, &config);
|
|
488
|
+
assert!(result.is_ok());
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
// SECURITY TESTS: CORS Attack Vectors
|
|
492
|
+
|
|
493
|
+
/// SECURITY TEST: Verify credentials=true with wildcard is caught
|
|
494
|
+
/// This is a critical vulnerability - RFC 6454 forbids this
|
|
495
|
+
#[test]
|
|
496
|
+
fn test_credentials_with_wildcard_config_is_security_issue() {
|
|
497
|
+
let config = CorsConfig {
|
|
498
|
+
allowed_origins: vec!["*".to_string()],
|
|
499
|
+
allowed_methods: vec!["GET".to_string()],
|
|
500
|
+
allowed_headers: vec![],
|
|
501
|
+
expose_headers: None,
|
|
502
|
+
max_age: None,
|
|
503
|
+
allow_credentials: Some(true), // SECURITY BUG: This should not be allowed with wildcard
|
|
504
|
+
};
|
|
505
|
+
|
|
506
|
+
let mut headers = HeaderMap::new();
|
|
507
|
+
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
508
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
509
|
+
|
|
510
|
+
let result = handle_preflight(&headers, &config);
|
|
511
|
+
|
|
512
|
+
// BUG: This should return 500 or reject the config, but instead succeeds
|
|
513
|
+
if let Ok(response) = result {
|
|
514
|
+
let resp_headers = response.headers();
|
|
515
|
+
let has_credentials = resp_headers
|
|
516
|
+
.get("access-control-allow-credentials")
|
|
517
|
+
.map(|v| v == "true")
|
|
518
|
+
.unwrap_or(false);
|
|
519
|
+
let origin_header = resp_headers.get("access-control-allow-origin");
|
|
520
|
+
|
|
521
|
+
if has_credentials && origin_header.is_some() {
|
|
522
|
+
let origin_val = origin_header.unwrap().to_str().unwrap_or("");
|
|
523
|
+
if origin_val == "*" {
|
|
524
|
+
panic!("SECURITY VULNERABILITY: credentials=true with origin=* allowed");
|
|
525
|
+
}
|
|
526
|
+
}
|
|
527
|
+
}
|
|
528
|
+
}
|
|
529
|
+
|
|
530
|
+
/// SECURITY TEST: Exact origin matching required
|
|
531
|
+
/// Subdomain like api.evil.example.com must NOT match example.com
|
|
532
|
+
#[test]
|
|
533
|
+
fn test_subdomain_bypass_blocked() {
|
|
534
|
+
let config = CorsConfig {
|
|
535
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
536
|
+
allowed_methods: vec!["GET".to_string()],
|
|
537
|
+
allowed_headers: vec![],
|
|
538
|
+
expose_headers: None,
|
|
539
|
+
max_age: None,
|
|
540
|
+
allow_credentials: None,
|
|
541
|
+
};
|
|
542
|
+
|
|
543
|
+
assert!(!is_origin_allowed("https://api.example.com", &config.allowed_origins));
|
|
544
|
+
assert!(!is_origin_allowed("https://evil.example.com", &config.allowed_origins));
|
|
545
|
+
assert!(!is_origin_allowed(
|
|
546
|
+
"https://sub.sub.example.com",
|
|
547
|
+
&config.allowed_origins
|
|
548
|
+
));
|
|
549
|
+
|
|
550
|
+
assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
/// SECURITY TEST: Port exact matching required
|
|
554
|
+
/// localhost:3001 must NOT match localhost:3000
|
|
555
|
+
#[test]
|
|
556
|
+
fn test_port_bypass_blocked() {
|
|
557
|
+
let config = CorsConfig {
|
|
558
|
+
allowed_origins: vec!["http://localhost:3000".to_string()],
|
|
559
|
+
allowed_methods: vec!["GET".to_string()],
|
|
560
|
+
allowed_headers: vec![],
|
|
561
|
+
expose_headers: None,
|
|
562
|
+
max_age: None,
|
|
563
|
+
allow_credentials: None,
|
|
564
|
+
};
|
|
565
|
+
|
|
566
|
+
assert!(!is_origin_allowed("http://localhost:3001", &config.allowed_origins));
|
|
567
|
+
assert!(!is_origin_allowed("http://localhost:8080", &config.allowed_origins));
|
|
568
|
+
assert!(!is_origin_allowed("http://localhost:443", &config.allowed_origins));
|
|
569
|
+
|
|
570
|
+
assert!(is_origin_allowed("http://localhost:3000", &config.allowed_origins));
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
/// SECURITY TEST: Protocol exact matching required
|
|
574
|
+
/// http://example.com must NOT match https://example.com
|
|
575
|
+
#[test]
|
|
576
|
+
fn test_protocol_downgrade_attack_blocked() {
|
|
577
|
+
let config = CorsConfig {
|
|
578
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
579
|
+
allowed_methods: vec!["GET".to_string()],
|
|
580
|
+
allowed_headers: vec![],
|
|
581
|
+
expose_headers: None,
|
|
582
|
+
max_age: None,
|
|
583
|
+
allow_credentials: None,
|
|
584
|
+
};
|
|
585
|
+
|
|
586
|
+
assert!(!is_origin_allowed("http://example.com", &config.allowed_origins));
|
|
587
|
+
assert!(!is_origin_allowed("ws://example.com", &config.allowed_origins));
|
|
588
|
+
assert!(!is_origin_allowed("wss://example.com", &config.allowed_origins));
|
|
589
|
+
|
|
590
|
+
assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
591
|
+
}
|
|
592
|
+
|
|
593
|
+
/// SECURITY TEST: Case sensitivity in origin matching
|
|
594
|
+
/// Origins should match exactly (including case)
|
|
595
|
+
#[test]
|
|
596
|
+
fn test_case_sensitive_origin_matching() {
|
|
597
|
+
let config = CorsConfig {
|
|
598
|
+
allowed_origins: vec!["https://Example.Com".to_string()],
|
|
599
|
+
allowed_methods: vec!["GET".to_string()],
|
|
600
|
+
allowed_headers: vec![],
|
|
601
|
+
expose_headers: None,
|
|
602
|
+
max_age: None,
|
|
603
|
+
allow_credentials: None,
|
|
604
|
+
};
|
|
605
|
+
|
|
606
|
+
assert!(!is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
607
|
+
assert!(!is_origin_allowed("https://EXAMPLE.COM", &config.allowed_origins));
|
|
608
|
+
|
|
609
|
+
assert!(is_origin_allowed("https://Example.Com", &config.allowed_origins));
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
/// SECURITY TEST: Trailing slash normalization
|
|
613
|
+
/// https://example.com/ should be treated differently from https://example.com
|
|
614
|
+
#[test]
|
|
615
|
+
fn test_trailing_slash_origin_not_normalized() {
|
|
616
|
+
let config = CorsConfig {
|
|
617
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
618
|
+
allowed_methods: vec!["GET".to_string()],
|
|
619
|
+
allowed_headers: vec![],
|
|
620
|
+
expose_headers: None,
|
|
621
|
+
max_age: None,
|
|
622
|
+
allow_credentials: None,
|
|
623
|
+
};
|
|
624
|
+
|
|
625
|
+
assert!(!is_origin_allowed("https://example.com/", &config.allowed_origins));
|
|
626
|
+
|
|
627
|
+
assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
/// SECURITY TEST: NULL origin and wildcard behavior
|
|
631
|
+
/// Special "null" origin used by file:// and sandboxed iframes
|
|
632
|
+
/// The current implementation treats "null" as a regular origin string,
|
|
633
|
+
/// which means it IS allowed by wildcard (not ideal but documents current behavior)
|
|
634
|
+
#[test]
|
|
635
|
+
fn test_null_origin_with_wildcard() {
|
|
636
|
+
let config = CorsConfig {
|
|
637
|
+
allowed_origins: vec!["*".to_string()],
|
|
638
|
+
allowed_methods: vec!["GET".to_string()],
|
|
639
|
+
allowed_headers: vec![],
|
|
640
|
+
expose_headers: None,
|
|
641
|
+
max_age: None,
|
|
642
|
+
allow_credentials: None,
|
|
643
|
+
};
|
|
644
|
+
|
|
645
|
+
// SECURITY NOTE: "null" origin is allowed by wildcard in current implementation
|
|
646
|
+
assert!(is_origin_allowed("null", &config.allowed_origins));
|
|
647
|
+
|
|
648
|
+
let with_explicit_null = CorsConfig {
|
|
649
|
+
allowed_origins: vec!["null".to_string()],
|
|
650
|
+
allowed_methods: vec!["GET".to_string()],
|
|
651
|
+
allowed_headers: vec![],
|
|
652
|
+
expose_headers: None,
|
|
653
|
+
max_age: None,
|
|
654
|
+
allow_credentials: None,
|
|
655
|
+
};
|
|
656
|
+
assert!(is_origin_allowed("null", &with_explicit_null.allowed_origins));
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
/// SECURITY TEST: Empty origin is always rejected
|
|
660
|
+
#[test]
|
|
661
|
+
fn test_empty_origin_always_rejected() {
|
|
662
|
+
let config_with_wildcard = CorsConfig {
|
|
663
|
+
allowed_origins: vec!["*".to_string()],
|
|
664
|
+
allowed_methods: vec!["GET".to_string()],
|
|
665
|
+
allowed_headers: vec![],
|
|
666
|
+
expose_headers: None,
|
|
667
|
+
max_age: None,
|
|
668
|
+
allow_credentials: None,
|
|
669
|
+
};
|
|
670
|
+
assert!(!is_origin_allowed("", &config_with_wildcard.allowed_origins));
|
|
671
|
+
|
|
672
|
+
let config_with_explicit = CorsConfig {
|
|
673
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
674
|
+
allowed_methods: vec!["GET".to_string()],
|
|
675
|
+
allowed_headers: vec![],
|
|
676
|
+
expose_headers: None,
|
|
677
|
+
max_age: None,
|
|
678
|
+
allow_credentials: None,
|
|
679
|
+
};
|
|
680
|
+
assert!(!is_origin_allowed("", &config_with_explicit.allowed_origins));
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
/// SECURITY TEST: Preflight with invalid origin should reject
|
|
684
|
+
#[test]
|
|
685
|
+
fn test_preflight_rejects_invalid_origin() {
|
|
686
|
+
let config = make_cors_config();
|
|
687
|
+
let mut headers = HeaderMap::new();
|
|
688
|
+
headers.insert("origin", HeaderValue::from_static("https://untrusted.com"));
|
|
689
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
690
|
+
|
|
691
|
+
let result = handle_preflight(&headers, &config);
|
|
692
|
+
assert!(result.is_err());
|
|
693
|
+
|
|
694
|
+
let response = *result.unwrap_err();
|
|
695
|
+
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
696
|
+
}
|
|
697
|
+
|
|
698
|
+
/// SECURITY TEST: Multiple origins - each must be exact match
|
|
699
|
+
#[test]
|
|
700
|
+
fn test_multiple_origins_exact_matching() {
|
|
701
|
+
let config = CorsConfig {
|
|
702
|
+
allowed_origins: vec!["https://trusted1.com".to_string(), "https://trusted2.com".to_string()],
|
|
703
|
+
allowed_methods: vec!["GET".to_string()],
|
|
704
|
+
allowed_headers: vec![],
|
|
705
|
+
expose_headers: None,
|
|
706
|
+
max_age: None,
|
|
707
|
+
allow_credentials: None,
|
|
708
|
+
};
|
|
709
|
+
|
|
710
|
+
assert!(is_origin_allowed("https://trusted1.com", &config.allowed_origins));
|
|
711
|
+
assert!(is_origin_allowed("https://trusted2.com", &config.allowed_origins));
|
|
712
|
+
|
|
713
|
+
assert!(!is_origin_allowed(
|
|
714
|
+
"https://trusted1.com.evil.com",
|
|
715
|
+
&config.allowed_origins
|
|
716
|
+
));
|
|
717
|
+
assert!(!is_origin_allowed("https://trusted3.com", &config.allowed_origins));
|
|
718
|
+
assert!(!is_origin_allowed("https://trusted.com", &config.allowed_origins));
|
|
719
|
+
}
|
|
720
|
+
|
|
721
|
+
/// SECURITY TEST: Wildcard origin should allow any origin (but check config)
|
|
722
|
+
#[test]
|
|
723
|
+
fn test_wildcard_allows_all_but_check_credentials() {
|
|
724
|
+
let config = CorsConfig {
|
|
725
|
+
allowed_origins: vec!["*".to_string()],
|
|
726
|
+
allowed_methods: vec!["GET".to_string()],
|
|
727
|
+
allowed_headers: vec![],
|
|
728
|
+
expose_headers: None,
|
|
729
|
+
max_age: None,
|
|
730
|
+
allow_credentials: None,
|
|
731
|
+
};
|
|
732
|
+
|
|
733
|
+
assert!(is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
734
|
+
assert!(is_origin_allowed("https://evil.com", &config.allowed_origins));
|
|
735
|
+
assert!(is_origin_allowed("http://localhost:3000", &config.allowed_origins));
|
|
736
|
+
|
|
737
|
+
assert!(!is_origin_allowed("", &config.allowed_origins));
|
|
738
|
+
}
|
|
739
|
+
|
|
740
|
+
/// SECURITY TEST: Preflight response headers must match config exactly
|
|
741
|
+
#[test]
|
|
742
|
+
fn test_preflight_response_has_correct_allowed_origins() {
|
|
743
|
+
let config = CorsConfig {
|
|
744
|
+
allowed_origins: vec!["https://trusted.com".to_string()],
|
|
745
|
+
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
|
746
|
+
allowed_headers: vec!["content-type".to_string()],
|
|
747
|
+
expose_headers: None,
|
|
748
|
+
max_age: Some(3600),
|
|
749
|
+
allow_credentials: Some(false),
|
|
750
|
+
};
|
|
751
|
+
|
|
752
|
+
let mut headers = HeaderMap::new();
|
|
753
|
+
headers.insert("origin", HeaderValue::from_static("https://trusted.com"));
|
|
754
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
755
|
+
headers.insert(
|
|
756
|
+
"access-control-request-headers",
|
|
757
|
+
HeaderValue::from_static("content-type"),
|
|
758
|
+
);
|
|
759
|
+
|
|
760
|
+
let result = handle_preflight(&headers, &config);
|
|
761
|
+
assert!(result.is_ok());
|
|
762
|
+
|
|
763
|
+
let response = result.unwrap();
|
|
764
|
+
let resp_headers = response.headers();
|
|
765
|
+
|
|
766
|
+
assert_eq!(
|
|
767
|
+
resp_headers.get("access-control-allow-origin").unwrap(),
|
|
768
|
+
"https://trusted.com"
|
|
769
|
+
);
|
|
770
|
+
|
|
771
|
+
assert!(
|
|
772
|
+
resp_headers
|
|
773
|
+
.get("access-control-allow-methods")
|
|
774
|
+
.unwrap()
|
|
775
|
+
.to_str()
|
|
776
|
+
.unwrap()
|
|
777
|
+
.contains("GET")
|
|
778
|
+
);
|
|
779
|
+
assert!(
|
|
780
|
+
resp_headers
|
|
781
|
+
.get("access-control-allow-methods")
|
|
782
|
+
.unwrap()
|
|
783
|
+
.to_str()
|
|
784
|
+
.unwrap()
|
|
785
|
+
.contains("POST")
|
|
786
|
+
);
|
|
787
|
+
|
|
788
|
+
assert!(resp_headers.get("access-control-allow-credentials").is_none());
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
/// SECURITY TEST: Origin not in allowed list must be rejected in preflight
|
|
792
|
+
#[test]
|
|
793
|
+
fn test_preflight_all_origins_require_validation() {
|
|
794
|
+
let config = CorsConfig {
|
|
795
|
+
allowed_origins: vec!["https://trusted.com".to_string()],
|
|
796
|
+
allowed_methods: vec!["GET".to_string()],
|
|
797
|
+
allowed_headers: vec![],
|
|
798
|
+
expose_headers: None,
|
|
799
|
+
max_age: None,
|
|
800
|
+
allow_credentials: None,
|
|
801
|
+
};
|
|
802
|
+
|
|
803
|
+
let test_cases = vec![
|
|
804
|
+
"https://trusted.com",
|
|
805
|
+
"https://evil.com",
|
|
806
|
+
"https://trusted.com.evil",
|
|
807
|
+
"http://trusted.com",
|
|
808
|
+
"",
|
|
809
|
+
];
|
|
810
|
+
|
|
811
|
+
for origin in test_cases {
|
|
812
|
+
let mut headers = HeaderMap::new();
|
|
813
|
+
headers.insert(
|
|
814
|
+
"origin",
|
|
815
|
+
HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("")),
|
|
816
|
+
);
|
|
817
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
818
|
+
|
|
819
|
+
let result = handle_preflight(&headers, &config);
|
|
820
|
+
|
|
821
|
+
if origin == "https://trusted.com" {
|
|
822
|
+
assert!(result.is_ok(), "Valid origin {} should be allowed", origin);
|
|
823
|
+
} else {
|
|
824
|
+
assert!(result.is_err(), "Invalid origin {} should be rejected", origin);
|
|
825
|
+
}
|
|
826
|
+
}
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
/// SECURITY TEST: Requested headers must be in allowed list
|
|
830
|
+
#[test]
|
|
831
|
+
fn test_preflight_validates_all_requested_headers() {
|
|
832
|
+
let config = CorsConfig {
|
|
833
|
+
allowed_origins: vec!["https://trusted.com".to_string()],
|
|
834
|
+
allowed_methods: vec!["POST".to_string()],
|
|
835
|
+
allowed_headers: vec!["content-type".to_string(), "authorization".to_string()],
|
|
836
|
+
expose_headers: None,
|
|
837
|
+
max_age: None,
|
|
838
|
+
allow_credentials: None,
|
|
839
|
+
};
|
|
840
|
+
|
|
841
|
+
let test_cases = vec![
|
|
842
|
+
("content-type", true),
|
|
843
|
+
("authorization", true),
|
|
844
|
+
("content-type, authorization", true),
|
|
845
|
+
("x-custom-header", false),
|
|
846
|
+
("content-type, x-custom", false),
|
|
847
|
+
];
|
|
848
|
+
|
|
849
|
+
for (headers_str, should_pass) in test_cases {
|
|
850
|
+
let mut headers = HeaderMap::new();
|
|
851
|
+
headers.insert("origin", HeaderValue::from_static("https://trusted.com"));
|
|
852
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("POST"));
|
|
853
|
+
headers.insert(
|
|
854
|
+
"access-control-request-headers",
|
|
855
|
+
HeaderValue::from_str(headers_str).unwrap(),
|
|
856
|
+
);
|
|
857
|
+
|
|
858
|
+
let result = handle_preflight(&headers, &config);
|
|
859
|
+
|
|
860
|
+
if should_pass {
|
|
861
|
+
assert!(
|
|
862
|
+
result.is_ok(),
|
|
863
|
+
"Preflight with valid headers '{}' should pass",
|
|
864
|
+
headers_str
|
|
865
|
+
);
|
|
866
|
+
} else {
|
|
867
|
+
assert!(
|
|
868
|
+
result.is_err(),
|
|
869
|
+
"Preflight with invalid headers '{}' should fail",
|
|
870
|
+
headers_str
|
|
871
|
+
);
|
|
872
|
+
}
|
|
873
|
+
}
|
|
874
|
+
}
|
|
875
|
+
|
|
876
|
+
/// SECURITY TEST: add_cors_headers should respect origin validation
|
|
877
|
+
#[test]
|
|
878
|
+
fn test_add_cors_headers_respects_origin() {
|
|
879
|
+
let config = CorsConfig {
|
|
880
|
+
allowed_origins: vec!["https://trusted.com".to_string()],
|
|
881
|
+
allowed_methods: vec!["GET".to_string()],
|
|
882
|
+
allowed_headers: vec![],
|
|
883
|
+
expose_headers: Some(vec!["x-custom".to_string()]),
|
|
884
|
+
max_age: None,
|
|
885
|
+
allow_credentials: Some(true),
|
|
886
|
+
};
|
|
887
|
+
|
|
888
|
+
let mut response = Response::new(Body::empty());
|
|
889
|
+
|
|
890
|
+
add_cors_headers(&mut response, "https://trusted.com", &config);
|
|
891
|
+
|
|
892
|
+
let headers = response.headers();
|
|
893
|
+
assert_eq!(
|
|
894
|
+
headers.get("access-control-allow-origin").unwrap(),
|
|
895
|
+
"https://trusted.com"
|
|
896
|
+
);
|
|
897
|
+
assert_eq!(headers.get("access-control-expose-headers").unwrap(), "x-custom");
|
|
898
|
+
assert_eq!(headers.get("access-control-allow-credentials").unwrap(), "true");
|
|
899
|
+
}
|
|
900
|
+
|
|
901
|
+
/// SECURITY TEST: validate_cors_request respects allowed origins
|
|
902
|
+
#[test]
|
|
903
|
+
fn test_validate_cors_request_origin_must_match() {
|
|
904
|
+
let config = CorsConfig {
|
|
905
|
+
allowed_origins: vec!["https://trusted.com".to_string()],
|
|
906
|
+
allowed_methods: vec!["GET".to_string()],
|
|
907
|
+
allowed_headers: vec![],
|
|
908
|
+
expose_headers: None,
|
|
909
|
+
max_age: None,
|
|
910
|
+
allow_credentials: None,
|
|
911
|
+
};
|
|
912
|
+
|
|
913
|
+
let mut headers = HeaderMap::new();
|
|
914
|
+
headers.insert("origin", HeaderValue::from_static("https://trusted.com"));
|
|
915
|
+
assert!(validate_cors_request(&headers, &config).is_ok());
|
|
916
|
+
|
|
917
|
+
let mut headers = HeaderMap::new();
|
|
918
|
+
headers.insert("origin", HeaderValue::from_static("https://evil.com"));
|
|
919
|
+
assert!(validate_cors_request(&headers, &config).is_err());
|
|
920
|
+
|
|
921
|
+
let headers = HeaderMap::new();
|
|
922
|
+
assert!(validate_cors_request(&headers, &config).is_ok());
|
|
923
|
+
}
|
|
924
|
+
|
|
925
|
+
/// SECURITY TEST: Preflight without requested method should fail
|
|
926
|
+
#[test]
|
|
927
|
+
fn test_preflight_requires_access_control_request_method() {
|
|
928
|
+
let config = make_cors_config();
|
|
929
|
+
let mut headers = HeaderMap::new();
|
|
930
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
931
|
+
|
|
932
|
+
let result = handle_preflight(&headers, &config);
|
|
933
|
+
assert!(result.is_ok());
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
/// SECURITY TEST: Case-insensitive method matching
|
|
937
|
+
#[test]
|
|
938
|
+
fn test_preflight_method_case_insensitive() {
|
|
939
|
+
let config = CorsConfig {
|
|
940
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
941
|
+
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
|
942
|
+
allowed_headers: vec![],
|
|
943
|
+
expose_headers: None,
|
|
944
|
+
max_age: None,
|
|
945
|
+
allow_credentials: None,
|
|
946
|
+
};
|
|
947
|
+
|
|
948
|
+
let test_cases = vec!["GET", "get", "Get", "POST", "post"];
|
|
949
|
+
|
|
950
|
+
for method in test_cases {
|
|
951
|
+
let mut headers = HeaderMap::new();
|
|
952
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
953
|
+
headers.insert("access-control-request-method", HeaderValue::from_str(method).unwrap());
|
|
954
|
+
|
|
955
|
+
let result = handle_preflight(&headers, &config);
|
|
956
|
+
assert!(
|
|
957
|
+
result.is_ok(),
|
|
958
|
+
"Method '{}' should be allowed (case-insensitive)",
|
|
959
|
+
method
|
|
960
|
+
);
|
|
961
|
+
}
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
/// SECURITY TEST: Ensure preflight max-age is set correctly
|
|
965
|
+
#[test]
|
|
966
|
+
fn test_preflight_max_age_header() {
|
|
967
|
+
let config = CorsConfig {
|
|
968
|
+
allowed_origins: vec!["https://example.com".to_string()],
|
|
969
|
+
allowed_methods: vec!["GET".to_string()],
|
|
970
|
+
allowed_headers: vec![],
|
|
971
|
+
expose_headers: None,
|
|
972
|
+
max_age: Some(7200),
|
|
973
|
+
allow_credentials: None,
|
|
974
|
+
};
|
|
975
|
+
|
|
976
|
+
let mut headers = HeaderMap::new();
|
|
977
|
+
headers.insert("origin", HeaderValue::from_static("https://example.com"));
|
|
978
|
+
headers.insert("access-control-request-method", HeaderValue::from_static("GET"));
|
|
979
|
+
|
|
980
|
+
let result = handle_preflight(&headers, &config);
|
|
981
|
+
assert!(result.is_ok());
|
|
982
|
+
|
|
983
|
+
let response = result.unwrap();
|
|
984
|
+
assert_eq!(response.headers().get("access-control-max-age").unwrap(), "7200");
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
/// SECURITY TEST: Wildcard partial patterns should not work
|
|
988
|
+
/// *.example.com style patterns are not supported (good!)
|
|
989
|
+
#[test]
|
|
990
|
+
fn test_wildcard_patterns_not_supported() {
|
|
991
|
+
let config = CorsConfig {
|
|
992
|
+
allowed_origins: vec!["*.example.com".to_string()],
|
|
993
|
+
allowed_methods: vec!["GET".to_string()],
|
|
994
|
+
allowed_headers: vec![],
|
|
995
|
+
expose_headers: None,
|
|
996
|
+
max_age: None,
|
|
997
|
+
allow_credentials: None,
|
|
998
|
+
};
|
|
999
|
+
|
|
1000
|
+
assert!(!is_origin_allowed("https://api.example.com", &config.allowed_origins));
|
|
1001
|
+
assert!(!is_origin_allowed("https://example.com", &config.allowed_origins));
|
|
1002
|
+
|
|
1003
|
+
assert!(is_origin_allowed("*.example.com", &config.allowed_origins));
|
|
1004
|
+
}
|
|
1005
|
+
}
|