@nmtjs/proxy 0.15.0-beta.1 → 0.15.0-beta.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/router.rs CHANGED
@@ -1,406 +1,513 @@
1
- use crate::config::{AppUpstream, ProxyConfig, UpstreamKind};
2
- use async_trait::async_trait;
3
- use http::{Uri, uri::PathAndQuery};
4
- use log::{debug, info};
5
- use napi::Error as NapiError;
6
- use napi::bindgen_prelude::Result as NapiResult;
7
- use pingora::{
8
- http::ResponseHeader,
9
- lb::{Backend, Backends, discovery::Static},
10
- prelude::*,
11
- protocols::l4::socket::SocketAddr as PingoraSocketAddr,
12
- services::background::{GenBackgroundService, background_service},
13
- };
14
- use pingora_load_balancing::health_check::TcpHealthCheck;
15
- use std::{
16
- collections::{BTreeSet, HashMap},
17
- net::ToSocketAddrs,
18
- sync::Arc,
19
- time::Duration,
20
- };
21
-
22
- pub type Cluster = LoadBalancer<RoundRobin>;
23
-
24
- pub struct ClusterEntry {
25
- pub balancer: Arc<Cluster>,
26
- pub sni: Option<String>,
1
+ use std::borrow::Cow;
2
+ use std::collections::HashMap;
3
+ use std::sync::Arc;
4
+
5
+ use arc_swap::ArcSwap;
6
+ use http::header;
7
+ use http::{StatusCode, Uri};
8
+ use pingora::http::RequestHeader;
9
+ use pingora::http::ResponseHeader;
10
+ use pingora::lb::{LoadBalancer, selection::RoundRobin};
11
+ use pingora::modules::http::HttpModules;
12
+ use pingora::proxy::{ProxyHttp, Session};
13
+ use pingora::upstreams::peer::HttpPeer;
14
+ use pingora::{Error, ErrorType, Result};
15
+
16
+ #[derive(Clone, Default)]
17
+ pub struct RouterConfig {
18
+ pub subdomain_routes: HashMap<String, String>,
19
+ pub path_routes: HashMap<String, String>,
20
+ pub default_app: Option<String>,
21
+ pub apps: HashMap<String, AppPools>,
27
22
  }
28
23
 
29
- pub struct Router {
30
- clusters: HashMap<String, HashMap<UpstreamKind, ClusterEntry>>,
31
- upstreams_tls: HashMap<String, bool>,
24
+ #[derive(Clone)]
25
+ pub struct AppPools {
26
+ pub http1: Option<PoolConfig>,
27
+ pub http2: Option<PoolConfig>,
32
28
  }
33
29
 
34
- #[derive(Default)]
35
- pub struct RouterCtx {
36
- app_info: AppInfoState,
30
+ #[derive(Clone)]
31
+ pub struct PoolConfig {
32
+ pub lb: Arc<LoadBalancer<RoundRobin>>,
33
+ pub secure: bool,
34
+ pub verify_hostname: String,
37
35
  }
38
36
 
39
- struct AppInfo {
40
- name: String,
41
- kind: UpstreamKind,
37
+ /// Pre-resolved pool information cached in context to avoid repeated lookups.
38
+ #[derive(Clone)]
39
+ pub struct ResolvedPool {
40
+ pub lb: Arc<LoadBalancer<RoundRobin>>,
41
+ pub secure: bool,
42
+ pub verify_hostname: String,
43
+ pub is_http2: bool,
42
44
  }
43
45
 
44
- #[derive(Default)]
45
- enum AppInfoState {
46
- #[default]
47
- Unknown,
48
- Missing,
49
- Found(AppInfo),
46
+ #[allow(dead_code)]
47
+ pub struct Router {
48
+ config: ArcSwap<RouterConfig>,
50
49
  }
51
50
 
52
- impl RouterCtx {
53
- fn app_info(&mut self, session: &Session) -> Option<&AppInfo> {
54
- if matches!(self.app_info, AppInfoState::Unknown) {
55
- self.app_info = match extract_app_name(session) {
56
- Some(name) => AppInfoState::Found(AppInfo {
57
- name: name.to_string(),
58
- kind: if session.is_upgrade_req() {
59
- UpstreamKind::Websocket
60
- } else {
61
- UpstreamKind::Http
62
- },
63
- }),
64
- None => AppInfoState::Missing,
65
- };
51
+ impl Router {
52
+ #[allow(dead_code)]
53
+ pub fn new(config: RouterConfig) -> Self {
54
+ Self {
55
+ config: ArcSwap::from_pointee(config),
66
56
  }
57
+ }
67
58
 
68
- match &self.app_info {
69
- AppInfoState::Found(info) => Some(info),
70
- _ => None,
71
- }
59
+ #[allow(dead_code)]
60
+ pub fn update(&self, config: RouterConfig) {
61
+ self.config.store(Arc::new(config));
72
62
  }
73
63
  }
74
64
 
75
- impl Router {
76
- fn new(
77
- clusters: HashMap<String, HashMap<UpstreamKind, ClusterEntry>>,
78
- upstreams_tls: HashMap<String, bool>,
79
- ) -> Self {
80
- Self {
81
- clusters,
82
- upstreams_tls,
83
- }
84
- }
65
+ #[derive(Clone)]
66
+ pub struct SharedRouter(pub Arc<Router>);
67
+
68
+ #[derive(Clone, Default)]
69
+ pub struct RouterCtx {
70
+ pub app_name: Option<String>,
71
+ pub path_rewrite_segment: Option<String>,
72
+ pub is_upgrade: bool,
73
+ /// Cached pool resolution from request_filter to avoid re-lookup in upstream_peer.
74
+ pub resolved_pool: Option<ResolvedPool>,
75
+ }
85
76
 
86
- fn cluster_for<'a>(&'a self, app_info: &AppInfo) -> Option<&'a ClusterEntry> {
87
- self.clusters
88
- .get(&app_info.name)
89
- .and_then(|entries| entries.get(&app_info.kind))
77
+ impl SharedRouter {
78
+ pub fn new(router: Arc<Router>) -> Self {
79
+ Self(router)
90
80
  }
91
81
  }
92
82
 
93
- #[async_trait]
94
- impl ProxyHttp for Router {
83
+ #[async_trait::async_trait]
84
+ impl ProxyHttp for SharedRouter {
95
85
  type CTX = RouterCtx;
96
86
 
97
87
  fn new_ctx(&self) -> Self::CTX {
98
88
  RouterCtx::default()
99
89
  }
100
90
 
91
+ fn init_downstream_modules(&self, modules: &mut HttpModules) {
92
+ // Keep Pingora's default behavior (disabled compression) explicit here so we
93
+ // have a clear extension point for adding static downstream modules later.
94
+ modules
95
+ .add_module(pingora::modules::http::compression::ResponseCompressionBuilder::enable(0));
96
+ }
97
+
98
+ async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
99
+ let config = self.0.config.load();
100
+
101
+ // TODO(vNext): Deterministic downstream error mapping.
102
+ // Today, a number of routing/upstream-selection failures bubble up as Pingora internal errors,
103
+ // which typically become HTTP 500 responses but without a fully controlled body/headers.
104
+ // Decide and implement a single, explicit downstream error policy for at least:
105
+ // - no application matched (no subdomain/path/default)
106
+ // - matched app has no pools configured
107
+ // - no pools available for request type (e.g. upgrade requires http1)
108
+ // - no healthy upstreams available
109
+ // Acceptance: response status/body/headers are stable across versions and covered by tests.
110
+
111
+ let host = extract_host(session);
112
+ let path_first_segment = extract_first_path_segment(session);
113
+ let is_upgrade = is_upgrade_request(session);
114
+
115
+ let mut app_name: Option<String> = None;
116
+ let mut rewrite_segment: Option<String> = None;
117
+
118
+ if let Some(host) = host.as_deref() {
119
+ app_name = config.subdomain_routes.get(host).cloned();
120
+ }
121
+
122
+ if app_name.is_none()
123
+ && let Some(seg) = path_first_segment
124
+ && let Some(app) = config.path_routes.get(seg).cloned()
125
+ {
126
+ app_name = Some(app);
127
+ rewrite_segment = Some(seg.to_string());
128
+ }
129
+
130
+ if app_name.is_none() {
131
+ app_name = config.default_app.clone();
132
+ }
133
+
134
+ ctx.app_name = app_name;
135
+ ctx.path_rewrite_segment = rewrite_segment;
136
+ ctx.is_upgrade = is_upgrade;
137
+
138
+ // Pre-resolve the pool to avoid repeated HashMap lookups in upstream_peer.
139
+ if let Some(ref app_name) = ctx.app_name
140
+ && let Some(pools) = config.apps.get(app_name)
141
+ {
142
+ let (pool, is_http2) = if is_upgrade {
143
+ (pools.http1.as_ref(), false)
144
+ } else if let Some(p) = pools.http2.as_ref() {
145
+ (Some(p), true)
146
+ } else {
147
+ (pools.http1.as_ref(), false)
148
+ };
149
+
150
+ if let Some(pool) = pool {
151
+ ctx.resolved_pool = Some(ResolvedPool {
152
+ lb: Arc::clone(&pool.lb),
153
+ secure: pool.secure,
154
+ verify_hostname: pool.verify_hostname.clone(),
155
+ is_http2,
156
+ });
157
+ }
158
+ }
159
+
160
+ // Deterministic behavior: Upgrade/WebSocket must go to an HTTP/1 pool.
161
+ // If no HTTP/1 pool exists for the matched app, respond with a consistent error.
162
+ if ctx.is_upgrade
163
+ && let Some(app_name) = ctx.app_name.as_deref()
164
+ && let Some(pools) = config.apps.get(app_name)
165
+ && pools.http1.is_none()
166
+ {
167
+ let mut resp = ResponseHeader::build(StatusCode::INTERNAL_SERVER_ERROR, Some(2))?;
168
+ let _ = resp.insert_header(header::CONTENT_LENGTH, 0);
169
+ session.write_response_header(Box::new(resp), true).await?;
170
+ return Ok(true);
171
+ }
172
+
173
+ Ok(false)
174
+ }
175
+
101
176
  async fn upstream_peer(
102
177
  &self,
103
- session: &mut Session,
104
- ctx: &mut RouterCtx,
178
+ _session: &mut Session,
179
+ ctx: &mut Self::CTX,
105
180
  ) -> Result<Box<HttpPeer>> {
106
- const NO_CLUSTER: ImmutStr = ImmutStr::Static("no matching application for request");
107
-
108
- let app_info = ctx.app_info(session).ok_or_else(|| {
109
- Error::create(
110
- ErrorType::ConnectError,
111
- ErrorSource::Internal,
112
- Some(NO_CLUSTER),
113
- None,
114
- )
115
- })?;
116
- let cluster = self.cluster_for(app_info).ok_or_else(|| {
117
- Error::create(
118
- ErrorType::ConnectError,
119
- ErrorSource::Internal,
120
- Some(NO_CLUSTER),
121
- None,
122
- )
123
- })?;
124
- const NO_UPSTREAM: ImmutStr = ImmutStr::Static("no available upstream for application");
125
- let upstream = cluster.balancer.select(b"", 256).ok_or_else(|| {
126
- Error::create(
127
- ErrorType::ConnectError,
128
- ErrorSource::Internal,
129
- Some(NO_UPSTREAM),
130
- None,
131
- )
132
- })?;
181
+ // Use pre-resolved pool from request_filter when available.
182
+ let Some(resolved) = ctx.resolved_pool.as_ref() else {
183
+ // Fallback: no pool was resolved (no app matched or no pools configured)
184
+ return Err(Error::explain(
185
+ ErrorType::InternalError,
186
+ "no upstream pool resolved",
187
+ ));
188
+ };
133
189
 
134
- let sni = cluster
135
- .sni
136
- .clone()
137
- .or_else(|| session.req_header().uri.host().map(|h| h.to_string()))
138
- .or_else(|| {
139
- session
140
- .req_header()
141
- .headers
142
- .get("host")
143
- .and_then(|v| v.to_str().ok())
144
- .map(|v| v.split(':').next().unwrap_or(v).to_string())
145
- })
146
- .unwrap_or_default();
147
-
148
- let enable_tls = *self
149
- .upstreams_tls
150
- .get(&upstream.addr.to_string())
151
- .unwrap_or(&false);
152
-
153
- match upstream.addr {
154
- PingoraSocketAddr::Inet(_) => Ok(Box::new(HttpPeer::new(upstream, enable_tls, sni))),
155
- PingoraSocketAddr::Unix(addr) => {
156
- let path = addr.as_pathname().and_then(|p| p.to_str()).ok_or_else(|| {
157
- Error::create(
158
- ErrorType::InternalError,
159
- ErrorSource::Internal,
160
- Some(ImmutStr::Static("invalid unix socket path")),
161
- None,
162
- )
163
- })?;
164
- let peer = HttpPeer::new_uds(path, enable_tls, sni).map_err(|e| {
165
- Error::create(
166
- ErrorType::InternalError,
167
- ErrorSource::Internal,
168
- Some(ImmutStr::Static("failed to create uds peer")),
169
- Some(Box::new(e)),
170
- )
171
- })?;
172
- Ok(Box::new(peer))
173
- }
190
+ let Some(backend) = resolved.lb.select(b"", 8) else {
191
+ return Err(Error::explain(
192
+ ErrorType::InternalError,
193
+ "no healthy upstreams available",
194
+ ));
195
+ };
196
+
197
+ let mut peer = HttpPeer::new(
198
+ backend.addr.clone(),
199
+ resolved.secure,
200
+ resolved.verify_hostname.clone(),
201
+ );
202
+ // For plaintext HTTP/2 upstreams (h2c), Pingora needs the peer's min HTTP version to be 2,
203
+ // otherwise it will assume HTTP/1.1 when no ALPN is present.
204
+ if resolved.is_http2 {
205
+ peer.options.set_http_version(2, 2);
206
+ } else {
207
+ peer.options.set_http_version(1, 1);
174
208
  }
209
+
210
+ Ok(Box::new(peer))
175
211
  }
176
212
 
177
213
  async fn upstream_request_filter(
178
214
  &self,
179
- session: &mut Session,
215
+ _session: &mut Session,
180
216
  upstream_request: &mut RequestHeader,
181
217
  ctx: &mut Self::CTX,
182
218
  ) -> Result<()> {
183
- let Some(app_info) = ctx.app_info(session) else {
219
+ let Some(seg) = ctx.path_rewrite_segment.as_deref() else {
184
220
  return Ok(());
185
221
  };
186
222
 
187
- if let Some(client_addr) = session.client_addr()
188
- && let Some(inet) = client_addr.as_inet()
189
- {
190
- let client_ip = inet.ip().to_string();
191
- let new_val = if let Some(existing) = upstream_request.headers.get("x-forwarded-for") {
192
- if let Ok(existing_str) = existing.to_str() {
193
- format!("{}, {}", existing_str, client_ip)
194
- } else {
195
- client_ip
196
- }
197
- } else {
198
- client_ip
199
- };
200
- upstream_request
201
- .insert_header("x-forwarded-for", new_val)
202
- .map_err(|e| {
203
- Error::create(
204
- ErrorType::InternalError,
205
- ErrorSource::Internal,
206
- Some(ImmutStr::Static("failed to set x-forwarded-for")),
207
- Some(Box::new(e)),
208
- )
209
- })?;
210
- }
211
-
212
- let name = app_info.name.as_str();
223
+ let Some(path_and_query) = upstream_request.uri.path_and_query().map(|pq| pq.as_str())
224
+ else {
225
+ return Ok(());
226
+ };
213
227
 
214
- let path = upstream_request.uri.path();
215
- let path_bytes = path.as_bytes();
216
- let mut start_idx = 0;
217
- while start_idx < path_bytes.len() && path_bytes[start_idx] == b'/' {
218
- start_idx += 1;
219
- }
228
+ let Some(new_path_and_query) = strip_first_path_segment(path_and_query, seg) else {
229
+ return Ok(());
230
+ };
220
231
 
221
- if path[start_idx..].starts_with(name) {
222
- let end_idx = start_idx + name.len();
223
- if end_idx == path.len() || path_bytes[end_idx] == b'/' {
224
- let mut new_path = &path[end_idx..];
225
- if new_path.is_empty() {
226
- new_path = "/";
227
- }
228
-
229
- let mut parts = upstream_request.uri.clone().into_parts();
230
- let path_and_query = if let Some(query) = upstream_request.uri.query() {
231
- let mut s = String::with_capacity(new_path.len() + 1 + query.len());
232
- s.push_str(new_path);
233
- s.push('?');
234
- s.push_str(query);
235
- s
236
- } else {
237
- new_path.to_string()
238
- };
239
-
240
- let pq = path_and_query.parse::<PathAndQuery>().map_err(|e| {
241
- Error::create(
242
- ErrorType::InternalError,
243
- ErrorSource::Internal,
244
- Some(ImmutStr::Static("invalid path")),
245
- Some(Box::new(e)),
246
- )
247
- })?;
248
-
249
- parts.path_and_query = Some(pq);
250
- let new_uri = Uri::from_parts(parts).map_err(|e| {
251
- Error::create(
252
- ErrorType::InternalError,
253
- ErrorSource::Internal,
254
- Some(ImmutStr::Static("invalid uri")),
255
- Some(Box::new(e)),
256
- )
257
- })?;
258
-
259
- debug!(
260
- "Rewriting upstream URI from {} to {}",
261
- upstream_request.uri, new_uri
262
- );
263
- upstream_request.set_uri(new_uri);
264
- }
265
- }
232
+ let uri = Uri::builder()
233
+ .path_and_query(new_path_and_query.as_ref())
234
+ .build()
235
+ .map_err(|e| {
236
+ Error::because(
237
+ ErrorType::InternalError,
238
+ "failed to rewrite upstream uri",
239
+ e,
240
+ )
241
+ })?;
242
+
243
+ upstream_request.set_uri(uri);
266
244
  Ok(())
267
245
  }
246
+ }
268
247
 
269
- async fn response_filter(
270
- &self,
271
- _session: &mut Session,
272
- _upstream_response: &mut ResponseHeader,
273
- _ctx: &mut Self::CTX,
274
- ) -> Result<()> {
275
- const REMOVE_HEADERS: [&str; 1] = ["uWebSockets"];
276
- for header in REMOVE_HEADERS {
277
- _upstream_response.remove_header(header);
278
- }
279
- Ok(())
280
- }
248
+ fn is_upgrade_request(session: &Session) -> bool {
249
+ // WebSocket/Upgrade is an HTTP/1.1 mechanism.
250
+ // Keep it simple: treat presence of `Upgrade` header as an upgrade request.
251
+ session.req_header().headers.get(header::UPGRADE).is_some()
281
252
  }
282
253
 
283
- pub struct RouterAssembly {
284
- pub router: Router,
285
- pub background_services: Vec<GenBackgroundService<Cluster>>,
254
+ fn extract_first_path_segment(session: &Session) -> Option<&str> {
255
+ let path = session.req_header().uri.path();
256
+ let mut parts = path.split('/').filter(|p| !p.is_empty());
257
+ parts.next()
286
258
  }
287
259
 
288
- pub fn build_router(config: &ProxyConfig) -> NapiResult<RouterAssembly> {
289
- let mut services = Vec::with_capacity(config.apps.len());
290
- let mut clusters = HashMap::with_capacity(config.apps.len());
291
- let mut upstreams_tls = HashMap::new();
292
-
293
- for (name, definition) in &config.apps {
294
- let mut app_clusters = HashMap::new();
295
-
296
- for (&kind, upstreams) in &definition.upstreams {
297
- let mut resolved_addrs = Vec::new();
298
- for upstream in upstreams {
299
- match upstream {
300
- AppUpstream::Port {
301
- secure,
302
- hostname,
303
- port,
304
- ..
305
- } => {
306
- let addr_str = format!("{}:{}", hostname, port);
307
- let addrs = addr_str.to_socket_addrs().map_err(|e| {
308
- NapiError::from_reason(format!(
309
- "failed to resolve '{}': {}",
310
- addr_str, e
311
- ))
312
- })?;
313
- for addr in addrs {
314
- let p_addr = PingoraSocketAddr::Inet(addr);
315
- resolved_addrs.push(p_addr.clone());
316
- upstreams_tls.insert(p_addr.to_string(), *secure);
317
- }
318
- }
319
- AppUpstream::Unix { secure, path } => {
320
- let p_addr = PingoraSocketAddr::Unix(
321
- std::os::unix::net::SocketAddr::from_pathname(path).map_err(|e| {
322
- NapiError::from_reason(format!(
323
- "failed to resolve unix socket '{}': {}",
324
- path, e
325
- ))
326
- })?,
327
- );
328
- resolved_addrs.push(p_addr.clone());
329
- upstreams_tls.insert(p_addr.to_string(), *secure);
330
- }
331
- }
332
- }
260
+ fn extract_host(session: &Session) -> Option<Cow<'_, str>> {
261
+ let headers = &session.req_header().headers;
333
262
 
334
- if resolved_addrs.is_empty() {
335
- continue;
336
- }
263
+ let host_str = headers
264
+ .get(header::HOST)
265
+ .and_then(|v| v.to_str().ok())
266
+ .or_else(|| {
267
+ headers
268
+ .get(http::HeaderName::from_static(":authority"))
269
+ .and_then(|v| v.to_str().ok())
270
+ })?;
337
271
 
338
- let cluster_name = format!("{}:{}", name, kind.as_str());
339
- let (balancer, service) =
340
- build_cluster_service(&cluster_name, resolved_addrs, config.health_check_interval)?;
341
-
342
- app_clusters.insert(
343
- kind,
344
- ClusterEntry {
345
- balancer,
346
- sni: definition.sni.clone(),
347
- },
348
- );
349
- if let Some(service) = service {
350
- services.push(service);
351
- }
272
+ let host_without_port = strip_port_str(host_str);
273
+
274
+ // Only allocate if lowercase conversion is needed
275
+ if host_without_port.chars().any(|c| c.is_ascii_uppercase()) {
276
+ Some(Cow::Owned(host_without_port.to_ascii_lowercase()))
277
+ } else {
278
+ Some(Cow::Borrowed(host_without_port))
279
+ }
280
+ }
281
+
282
+ fn strip_first_path_segment<'a>(path_and_query: &'a str, segment: &str) -> Option<Cow<'a, str>> {
283
+ let (path, query) = match path_and_query.split_once('?') {
284
+ Some((p, q)) => (p, Some(q)),
285
+ None => (path_and_query, None),
286
+ };
287
+
288
+ let prefix_len = segment.len() + 1; // "/{segment}".len()
289
+
290
+ // Check if path matches "/{segment}" exactly or starts with "/{segment}/"
291
+ if !path.starts_with('/') || path.len() < prefix_len {
292
+ return None;
293
+ }
294
+
295
+ let after_slash = &path[1..];
296
+ if !after_slash.starts_with(segment) {
297
+ return None;
298
+ }
299
+
300
+ // Check boundary: must be exact match or followed by '/'
301
+ let remainder = &path[prefix_len..];
302
+ let rewritten_path = if remainder.is_empty() {
303
+ // path == "/{segment}"
304
+ "/"
305
+ } else if remainder.starts_with('/') {
306
+ // path starts with "/{segment}/"
307
+ remainder
308
+ } else {
309
+ // path is like "/{segment}xyz" - not a boundary match
310
+ return None;
311
+ };
312
+
313
+ // If no query string, we can return a borrowed slice
314
+ match query {
315
+ None => Some(Cow::Borrowed(rewritten_path)),
316
+ Some(q) => {
317
+ // Must allocate to concatenate path + "?" + query
318
+ let mut out = String::with_capacity(rewritten_path.len() + 1 + q.len());
319
+ out.push_str(rewritten_path);
320
+ out.push('?');
321
+ out.push_str(q);
322
+ Some(Cow::Owned(out))
352
323
  }
324
+ }
325
+ }
353
326
 
354
- if !app_clusters.is_empty() {
355
- clusters.insert(name.clone(), app_clusters);
327
+ /// Strip port from host string, returning a slice (zero allocation).
328
+ fn strip_port_str(host: &str) -> &str {
329
+ // "example.com:3000" => "example.com"
330
+ // "[::1]:3000" => "[::1]"
331
+ if let Some(stripped) = host.strip_prefix('[') {
332
+ // IPv6 address: find closing bracket
333
+ if let Some(end) = stripped.find(']') {
334
+ return &host[..end + 2]; // Include brackets: "[" + content + "]"
356
335
  }
336
+ return host;
357
337
  }
358
338
 
359
- let router = Router::new(clusters, upstreams_tls);
360
- Ok(RouterAssembly {
361
- router,
362
- background_services: services,
363
- })
339
+ match host.rsplit_once(':') {
340
+ Some((h, port)) if !h.is_empty() && port.chars().all(|c| c.is_ascii_digit()) => h,
341
+ _ => host,
342
+ }
364
343
  }
365
344
 
366
- fn build_cluster_service(
367
- name: &str,
368
- upstreams: Vec<PingoraSocketAddr>,
369
- health_interval: Option<Duration>,
370
- ) -> NapiResult<(Arc<Cluster>, Option<GenBackgroundService<Cluster>>)> {
371
- info!(
372
- "Building cluster for app '{name}' with upstreams: {:?}",
373
- upstreams
374
- );
375
- let backends_vec: Vec<Backend> = upstreams
376
- .into_iter()
377
- .map(|addr| {
378
- let addr_str = addr.to_string();
379
- Backend::new(&addr_str).map_err(|e| {
380
- NapiError::from_reason(format!(
381
- "failed to create backend for '{}': {}",
382
- addr_str, e
383
- ))
384
- })
385
- })
386
- .collect::<Result<Vec<_>, _>>()?;
387
- let backends_set: BTreeSet<Backend> = backends_vec.into_iter().collect();
388
- let discovery = Static::new(backends_set);
389
- let backends = Backends::new(discovery);
390
- let mut balancer = LoadBalancer::from_backends(backends);
391
-
392
- if let Some(interval) = health_interval {
393
- balancer.set_health_check(TcpHealthCheck::new());
394
- balancer.health_check_frequency = Some(interval);
395
- let service = background_service("cluster health check", balancer);
396
- Ok((service.task(), Some(service)))
397
- } else {
398
- Ok((Arc::new(balancer), None))
345
+ #[cfg(test)]
346
+ mod tests {
347
+ use super::strip_first_path_segment;
348
+
349
+ #[test]
350
+ fn path_rewrite_strips_exact_segment() {
351
+ assert_eq!(
352
+ strip_first_path_segment("/auth", "auth").as_deref(),
353
+ Some("/")
354
+ );
355
+ assert_eq!(
356
+ strip_first_path_segment("/auth/", "auth").as_deref(),
357
+ Some("/")
358
+ );
359
+ assert_eq!(
360
+ strip_first_path_segment("/auth/login", "auth").as_deref(),
361
+ Some("/login")
362
+ );
399
363
  }
400
- }
401
364
 
402
- fn extract_app_name(session: &Session) -> Option<&str> {
403
- let path = session.req_header().uri.path();
404
- let without_slash = path.trim_start_matches('/');
405
- without_slash.split('/').find(|segment| !segment.is_empty())
365
+ #[test]
366
+ fn path_rewrite_preserves_query_string() {
367
+ assert_eq!(
368
+ strip_first_path_segment("/auth/login?x=1", "auth").as_deref(),
369
+ Some("/login?x=1")
370
+ );
371
+ }
372
+
373
+ #[test]
374
+ fn path_rewrite_is_boundary_aware() {
375
+ assert_eq!(strip_first_path_segment("/authz", "auth"), None);
376
+ assert_eq!(strip_first_path_segment("/authz/login", "auth"), None);
377
+ assert_eq!(strip_first_path_segment("/a", "auth"), None);
378
+ }
379
+
380
+ #[test]
381
+ fn strip_port_handles_ipv4() {
382
+ use super::strip_port_str;
383
+ assert_eq!(strip_port_str("example.com:3000"), "example.com");
384
+ assert_eq!(strip_port_str("example.com"), "example.com");
385
+ assert_eq!(strip_port_str("127.0.0.1:8080"), "127.0.0.1");
386
+ }
387
+
388
+ #[test]
389
+ fn strip_port_handles_ipv6() {
390
+ use super::strip_port_str;
391
+ assert_eq!(strip_port_str("[::1]:3000"), "[::1]");
392
+ assert_eq!(strip_port_str("[::1]"), "[::1]");
393
+ assert_eq!(strip_port_str("[2001:db8::1]:443"), "[2001:db8::1]");
394
+ }
395
+
396
+ #[test]
397
+ fn strip_port_edge_cases() {
398
+ use super::strip_port_str;
399
+ // Empty string
400
+ assert_eq!(strip_port_str(""), "");
401
+ // Trailing colon - empty "port" is all digits (vacuously true), so strips
402
+ assert_eq!(strip_port_str("host:"), "host");
403
+ // Non-numeric port (should not strip)
404
+ assert_eq!(strip_port_str("host:abc"), "host:abc");
405
+ // Multiple colons without brackets (last segment is port-like)
406
+ assert_eq!(strip_port_str("a:b:80"), "a:b");
407
+ // Only port number - empty host, does not strip
408
+ assert_eq!(strip_port_str(":8080"), ":8080");
409
+ // Malformed IPv6 (no closing bracket)
410
+ assert_eq!(strip_port_str("[::1"), "[::1");
411
+ // IPv6 with trailing content after bracket
412
+ assert_eq!(strip_port_str("[::1]abc"), "[::1]");
413
+ }
414
+
415
+ #[test]
416
+ fn path_rewrite_edge_cases() {
417
+ // Root path - no segment to strip
418
+ assert_eq!(strip_first_path_segment("/", "auth"), None);
419
+ // Empty segment name
420
+ assert_eq!(strip_first_path_segment("/auth", ""), None);
421
+ // Deeply nested paths
422
+ assert_eq!(
423
+ strip_first_path_segment("/auth/a/b/c/d", "auth").as_deref(),
424
+ Some("/a/b/c/d")
425
+ );
426
+ // Query string only on root segment
427
+ assert_eq!(
428
+ strip_first_path_segment("/auth?redirect=home", "auth").as_deref(),
429
+ Some("/?redirect=home")
430
+ );
431
+ // Multiple query parameters
432
+ assert_eq!(
433
+ strip_first_path_segment("/auth/login?a=1&b=2&c=3", "auth").as_deref(),
434
+ Some("/login?a=1&b=2&c=3")
435
+ );
436
+ // Path with encoded characters
437
+ assert_eq!(
438
+ strip_first_path_segment("/auth/path%20with%20spaces", "auth").as_deref(),
439
+ Some("/path%20with%20spaces")
440
+ );
441
+ // Segment with special chars (if segment itself has special chars)
442
+ assert_eq!(
443
+ strip_first_path_segment("/auth-service/login", "auth-service").as_deref(),
444
+ Some("/login")
445
+ );
446
+ }
447
+
448
+ #[test]
449
+ fn path_rewrite_returns_borrowed_when_no_query() {
450
+ use std::borrow::Cow;
451
+ // Without query string, should return Cow::Borrowed
452
+ let result = strip_first_path_segment("/auth/login", "auth");
453
+ assert!(matches!(result, Some(Cow::Borrowed(_))));
454
+
455
+ // With query string, must allocate (Cow::Owned)
456
+ let result = strip_first_path_segment("/auth/login?x=1", "auth");
457
+ assert!(matches!(result, Some(Cow::Owned(_))));
458
+ }
459
+
460
+ #[test]
461
+ fn path_rewrite_no_match_cases() {
462
+ // Completely different segment
463
+ assert_eq!(strip_first_path_segment("/users/login", "auth"), None);
464
+ // Segment is prefix but not at boundary
465
+ assert_eq!(strip_first_path_segment("/authorization", "auth"), None);
466
+ // Case sensitive - should not match
467
+ assert_eq!(strip_first_path_segment("/Auth/login", "auth"), None);
468
+ assert_eq!(strip_first_path_segment("/AUTH/login", "auth"), None);
469
+ // Missing leading slash
470
+ assert_eq!(strip_first_path_segment("auth/login", "auth"), None);
471
+ }
472
+
473
+ #[test]
474
+ fn lowercase_host_helper() {
475
+ use super::strip_port_str;
476
+ use std::borrow::Cow;
477
+
478
+ // Helper to test the lowercase Cow logic (extracted from extract_host)
479
+ fn normalize_host(host: &str) -> Cow<'_, str> {
480
+ let host_without_port = strip_port_str(host);
481
+ if host_without_port.chars().any(|c| c.is_ascii_uppercase()) {
482
+ Cow::Owned(host_without_port.to_ascii_lowercase())
483
+ } else {
484
+ Cow::Borrowed(host_without_port)
485
+ }
486
+ }
487
+
488
+ // Already lowercase - should borrow
489
+ let result = normalize_host("example.com");
490
+ assert!(matches!(result, Cow::Borrowed(_)));
491
+ assert_eq!(result, "example.com");
492
+
493
+ // Uppercase - should allocate and lowercase
494
+ let result = normalize_host("Example.COM");
495
+ assert!(matches!(result, Cow::Owned(_)));
496
+ assert_eq!(result, "example.com");
497
+
498
+ // Mixed case with port
499
+ let result = normalize_host("Example.com:8080");
500
+ assert!(matches!(result, Cow::Owned(_)));
501
+ assert_eq!(result, "example.com");
502
+
503
+ // Lowercase with port - should borrow
504
+ let result = normalize_host("example.com:8080");
505
+ assert!(matches!(result, Cow::Borrowed(_)));
506
+ assert_eq!(result, "example.com");
507
+
508
+ // IPv6 uppercase (rare but possible)
509
+ let result = normalize_host("[::1]:8080");
510
+ assert!(matches!(result, Cow::Borrowed(_)));
511
+ assert_eq!(result, "[::1]");
512
+ }
406
513
  }