osprey 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rubocop.yml +8 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +132 -0
- data/Cargo.lock +1350 -0
- data/Cargo.toml +7 -0
- data/LICENSE.txt +21 -0
- data/README.md +63 -0
- data/Rakefile +22 -0
- data/exe/osprey +77 -0
- data/ext/osprey/Cargo.toml +28 -0
- data/ext/osprey/extconf.rb +6 -0
- data/ext/osprey/src/.DS_Store +0 -0
- data/ext/osprey/src/lib.rs +1659 -0
- data/ext/osprey/src/structs/mod.rs +2 -0
- data/ext/osprey/src/structs/rack_request.rs +168 -0
- data/ext/osprey/src/structs/rack_response.rs +152 -0
- data/flamegraph.svg +491 -0
- data/lib/osprey/.DS_Store +0 -0
- data/lib/osprey/patches.rb +8 -0
- data/lib/osprey/rack/handler/osprey.rb +22 -0
- data/lib/osprey/version.rb +5 -0
- data/lib/osprey.rb +53 -0
- data/server.cert +0 -0
- data/server.key +0 -0
- data/sig/hummingbird.rbs +4 -0
- data/test.crt +29 -0
- data/test.key +52 -0
- metadata +100 -0
@@ -0,0 +1,1659 @@
|
|
1
|
+
#![deny(unused_crate_dependencies)]
|
2
|
+
|
3
|
+
use bytes::Bytes;
|
4
|
+
use crossbeam::channel::{Receiver, Sender};
|
5
|
+
use http_body_util::combinators::BoxBody;
|
6
|
+
use http_body_util::{BodyExt, Empty};
|
7
|
+
use hyper::body::{Body, Incoming};
|
8
|
+
use hyper::header::{HeaderName, HeaderValue};
|
9
|
+
use hyper::service::service_fn;
|
10
|
+
use hyper::upgrade::Upgraded;
|
11
|
+
use hyper::{HeaderMap, Request, Response, StatusCode};
|
12
|
+
use hyper_util::rt::{TokioExecutor, TokioIo};
|
13
|
+
use hyper_util::server::conn::auto::Builder;
|
14
|
+
use libc::{sighandler_t, sleep};
|
15
|
+
use log::{debug, error, info, warn};
|
16
|
+
use magnus::value::{Id, Lazy, LazyId, ReprValue};
|
17
|
+
use magnus::{
|
18
|
+
block::Proc,
|
19
|
+
exception, function,
|
20
|
+
scan_args::{get_kwargs, scan_args, Args, KwArgs},
|
21
|
+
Error, Object, RHash, Value,
|
22
|
+
};
|
23
|
+
use magnus::{eval, Class, Module, RModule, Ruby};
|
24
|
+
use magnus::{RArray, RClass};
|
25
|
+
use nix::cmsg_space;
|
26
|
+
use nix::sys::socket::{
|
27
|
+
recvmsg, sendmsg, socketpair, AddressFamily, ControlMessage, ControlMessageOwned, MsgFlags,
|
28
|
+
SockFlag, SockType,
|
29
|
+
};
|
30
|
+
use std::cell::RefCell;
|
31
|
+
use std::convert::Infallible;
|
32
|
+
use std::os::unix::io::RawFd;
|
33
|
+
use std::os::unix::net::UnixStream;
|
34
|
+
use std::ptr::null_mut;
|
35
|
+
use structs::rack_request::{define_request_info, RackRequest, RackRequestInner};
|
36
|
+
use tokio::io::unix::AsyncFd;
|
37
|
+
use tokio::io::{AsyncRead, AsyncWrite};
|
38
|
+
use tokio::sync::broadcast::channel;
|
39
|
+
use tokio_stream::StreamExt;
|
40
|
+
|
41
|
+
use rb_sys::{rb_thread_call_with_gvl, rb_thread_call_without_gvl, rb_thread_create};
|
42
|
+
use simple_logger::SimpleLogger;
|
43
|
+
use socket2::{Domain, Protocol, Socket, Type};
|
44
|
+
use std::io::{IoSlice, IoSliceMut};
|
45
|
+
use std::net::{TcpListener, TcpStream};
|
46
|
+
use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd};
|
47
|
+
use std::panic;
|
48
|
+
use std::pin::Pin;
|
49
|
+
use std::sync::atomic::Ordering;
|
50
|
+
use std::sync::atomic::{AtomicBool, AtomicUsize};
|
51
|
+
use std::sync::{Mutex, RwLock};
|
52
|
+
use std::{
|
53
|
+
collections::HashMap, ffi::c_void, fs::OpenOptions, net::SocketAddr, os::fd::AsRawFd,
|
54
|
+
sync::Arc, time::Duration,
|
55
|
+
};
|
56
|
+
use structs::rack_response::{define_collector, RackResponse};
|
57
|
+
use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream};
|
58
|
+
use tokio::runtime::Builder as RuntimeBuilder;
|
59
|
+
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
60
|
+
use tokio::sync::oneshot::{self};
|
61
|
+
use tokio::task::{JoinError, JoinHandle};
|
62
|
+
use tokio::time::sleep as TokioSleep;
|
63
|
+
use tokio_rustls::rustls::ServerConfig;
|
64
|
+
use tokio_rustls::TlsAcceptor;
|
65
|
+
|
66
|
+
mod structs;
|
67
|
+
|
68
|
+
trait IoStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
|
69
|
+
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin> IoStream for T {}
|
70
|
+
|
71
|
+
// --------------------------------------------------
// Global/Static shared data for forking & shutdown
// --------------------------------------------------

/// Worker processes known to this (master) process: each entry pairs a worker PID with the
/// Unix-socket end used to send it accepted connections. `accept_loop_async` snapshots this
/// list and round-robins accepted sockets across it. It is populated elsewhere (not visible
/// here).
static CHILDREN: RwLock<Vec<ChildProcess>> = RwLock::new(vec![]);

/// Shutdown flag. The master accept loop checks it after each accepted connection and
/// leaves the loop once it is set. NOTE(review): the writer is not in the visible code.
static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);
/// Broadcast sender used to wake the currently running accept loop for shutdown. Each
/// accept loop installs a fresh sender here when it starts. NOTE(review): presumably
/// `initiate_shutdown` (defined elsewhere) sends on it — confirm.
static CTRL_C_SIGNAL_CHANNEL: RwLock<Option<tokio::sync::broadcast::Sender<()>>> =
    RwLock::new(None);

/// The `Osprey` Ruby module, defined (or fetched if already defined) on first access.
pub static OSPREY: Lazy<RModule> = Lazy::new(|ruby| ruby.define_module("Osprey").unwrap());
// Interned Ruby method names, resolved lazily on first use.
pub static ID_SCHEDULE: LazyId = LazyId::new("schedule");
pub static ID_CALL: LazyId = LazyId::new("call");
pub static ID_NEW: LazyId = LazyId::new("new");
pub static ID_PUSH: LazyId = LazyId::new("push");
pub static ID_RESUME: LazyId = LazyId::new("resume");
pub static ID_TRANSFORM_VALUES: LazyId = LazyId::new("transform_values");
/// `Kernel.method(:Array).to_proc`: a Proc that calls `Kernel#Array` on its argument.
pub static ARRAY_PROC: Lazy<Proc> = Lazy::new(|ruby| {
    ruby.module_kernel()
        .funcall::<&str, (&str,), Value>("method", ("Array",))
        .unwrap()
        .funcall("to_proc", ())
        .unwrap()
});
/// A new instance of the top-level `Queue` class. It is also stored in the `Osprey`
/// module's `@work_queue` instance variable so Ruby code can reach it. `define_class`
/// presumably returns the already-defined `Queue` class rather than creating one — confirm
/// against magnus docs.
pub static WORK_QUEUE: Lazy<Value> = Lazy::new(|ruby| {
    let queue = ruby
        .define_class("Queue", ruby.class_object())
        .unwrap()
        .funcall::<&str, (), Value>("new", ())
        .unwrap();
    ruby.get_inner(&OSPREY)
        .ivar_set("@work_queue", queue)
        .unwrap();
    queue
});

/// `Async::Scheduler` (presumably from the `async` gem). Panics on first access if the
/// constant is not defined.
static SCHEDULER_CLASS: Lazy<RClass> = Lazy::new(|ruby| {
    ruby.define_module("Async")
        .unwrap()
        .const_get("Scheduler")
        .unwrap()
});

/// Ruby's `Fiber` class.
static FIBER_CLASS: Lazy<RClass> =
    Lazy::new(|ruby| ruby.define_class("Fiber", ruby.class_object()).unwrap());
|
116
|
+
/// Server configuration: bind address and ports, optional TLS material, worker/thread
/// counts, shutdown timeout, and optional Ruby fork hooks.
///
/// Note: this struct does not hold the Rack application itself; its only Ruby callables
/// are `before_fork` / `after_fork`.
#[derive(Clone)]
struct RackServerConfig {
    // Bind host. `setup_listeners` formats it as `"{host}:{port}"` and parses a
    // `SocketAddr`, so it must be an IP literal (IPv6 in brackets), not a hostname.
    host: String,
    // Port for the HTTPS listener, which is bound only when both `cert_path` and
    // `key_path` are set.
    port: u16,
    // Port for the plain-HTTP listener; `None` means no HTTP listener.
    http_port: Option<u16>,
    // NOTE(review): presumably the number of forked worker processes — not read in the
    // visible code.
    workers: u16,
    // Thread count; presumably copied into `AcceptLoopArgs::threads` — confirm.
    threads: u16,
    // TLS certificate / private key paths; HTTPS is enabled only when both are set.
    cert_path: Option<String>,
    key_path: Option<String>,
    // Grace period for in-flight requests on shutdown; presumably copied into
    // `AcceptLoopArgs::shutdown_timeout` — confirm.
    shutdown_timeout: Duration,
    // Presumably the Rack `SCRIPT_NAME` for mounted apps — confirm.
    script_name: String,
    // Optional Ruby procs, presumably invoked around forking workers — confirm.
    before_fork: Option<Proc>,
    after_fork: Option<Proc>,
}
|
131
|
+
|
132
|
+
/// A worker process as seen by the master.
#[derive(Debug, Clone)]
struct ChildProcess {
    // Worker PID.
    pid: i32,
    // Master's end of the Unix socket over which `accept_loop_async` sends accepted
    // connection fds (SCM_RIGHTS). Wrapped in `Arc` so the struct can derive `Clone`
    // (`OwnedFd` is not `Clone`); the fd closes when the last clone is dropped.
    writer: Arc<OwnedFd>,
}
|
137
|
+
|
138
|
+
/// State handed (as a raw pointer) to the accept-loop thread entry points
/// (`accept_loop` in the master, `accept_ruby_connection_loop` in workers).
struct AcceptLoopArgs {
    // Std listeners. The master wraps clones of them in Tokio listeners; the worker path
    // drops them immediately.
    http_listener: Option<TcpListener>,
    https_listener: Option<TcpListener>,
    // Worker: maximum time to wait for in-flight connections after shutdown. Master: only
    // logged.
    shutdown_timeout: Duration,
    // Master: Tokio worker-thread count. Worker: size of the Ruby thread pool.
    threads: u16,
    // Worker's end of the Unix socket on which connection fds arrive from the master.
    // The worker path unwraps it, so it must be `Some` there.
    reader: Option<OwnedFd>,
    // Forwarded to `handle_connection` for each request.
    script_name: String,
    // Passed to `configure_tls` in the worker to build the TLS acceptor.
    cert_path: Option<String>,
    key_path: Option<String>,
}
|
148
|
+
|
149
|
+
unsafe extern "C" fn accept_loop_interrupt(_ptr: *mut c_void) {
|
150
|
+
info!("(PID={}) Interrupt accept loop", libc::getpid());
|
151
|
+
}
|
152
|
+
|
153
|
+
unsafe fn setup_listeners(
|
154
|
+
server: &RackServerConfig,
|
155
|
+
) -> Result<(Option<TcpListener>, Option<TcpListener>), Error> {
|
156
|
+
let mut http_listener = None;
|
157
|
+
let mut https_listener = None;
|
158
|
+
|
159
|
+
// Set up HTTP listener if http_port is provided
|
160
|
+
if let Some(http_port) = server.http_port {
|
161
|
+
let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
|
162
|
+
Error::new(
|
163
|
+
exception::io_error(),
|
164
|
+
format!("Socket creation failed: {}", e),
|
165
|
+
)
|
166
|
+
})?;
|
167
|
+
socket.set_reuse_address(true).ok();
|
168
|
+
socket.set_reuse_port(true).ok();
|
169
|
+
socket.set_nonblocking(true).ok();
|
170
|
+
socket.set_nodelay(true).ok();
|
171
|
+
socket.set_recv_buffer_size(1048576).ok();
|
172
|
+
|
173
|
+
let address: SocketAddr =
|
174
|
+
format!("{}:{}", server.host, http_port)
|
175
|
+
.parse()
|
176
|
+
.map_err(|e| {
|
177
|
+
Error::new(
|
178
|
+
exception::io_error(),
|
179
|
+
format!(
|
180
|
+
"Failed to parse address {}: {}",
|
181
|
+
format!("{}:{}", server.host, http_port),
|
182
|
+
e
|
183
|
+
),
|
184
|
+
)
|
185
|
+
})?;
|
186
|
+
socket.bind(&address.into()).map_err(|e| {
|
187
|
+
Error::new(
|
188
|
+
exception::io_error(),
|
189
|
+
format!("bind failed for HTTP: {}", e),
|
190
|
+
)
|
191
|
+
})?;
|
192
|
+
socket.listen(1024).map_err(|e| {
|
193
|
+
Error::new(
|
194
|
+
exception::io_error(),
|
195
|
+
format!("listen failed for HTTP: {}", e),
|
196
|
+
)
|
197
|
+
})?;
|
198
|
+
let listener: TcpListener = socket.into();
|
199
|
+
listener.set_nonblocking(true).map_err(|e| {
|
200
|
+
Error::new(
|
201
|
+
exception::io_error(),
|
202
|
+
format!("set_nonblocking failed for HTTP: {}", e),
|
203
|
+
)
|
204
|
+
})?;
|
205
|
+
info!("(PID={}) Listening on http://{}", libc::getpid(), address);
|
206
|
+
http_listener = Some(listener);
|
207
|
+
}
|
208
|
+
|
209
|
+
// Set up HTTPS listener if cert_path and key_path are provided
|
210
|
+
if server.cert_path.is_some() && server.key_path.is_some() {
|
211
|
+
let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
|
212
|
+
Error::new(
|
213
|
+
exception::io_error(),
|
214
|
+
format!("Socket creation failed: {}", e),
|
215
|
+
)
|
216
|
+
})?;
|
217
|
+
socket.set_reuse_address(true).ok();
|
218
|
+
socket.set_reuse_port(true).ok();
|
219
|
+
socket.set_nonblocking(true).ok();
|
220
|
+
socket.set_nodelay(true).ok();
|
221
|
+
socket.set_recv_buffer_size(1048576).ok();
|
222
|
+
|
223
|
+
let address: SocketAddr =
|
224
|
+
format!("{}:{}", server.host, server.port)
|
225
|
+
.parse()
|
226
|
+
.map_err(|e| {
|
227
|
+
Error::new(
|
228
|
+
exception::io_error(),
|
229
|
+
format!(
|
230
|
+
"Failed to parse address {}: {}",
|
231
|
+
format!("{}:{}", server.host, server.port),
|
232
|
+
e
|
233
|
+
),
|
234
|
+
)
|
235
|
+
})?;
|
236
|
+
socket.bind(&address.into()).map_err(|e| {
|
237
|
+
Error::new(
|
238
|
+
exception::io_error(),
|
239
|
+
format!("bind failed for HTTPS: {}", e),
|
240
|
+
)
|
241
|
+
})?;
|
242
|
+
socket.listen(1024).map_err(|e| {
|
243
|
+
Error::new(
|
244
|
+
exception::io_error(),
|
245
|
+
format!("listen failed for HTTPS: {}", e),
|
246
|
+
)
|
247
|
+
})?;
|
248
|
+
let listener: TcpListener = socket.into();
|
249
|
+
listener.set_nonblocking(true).map_err(|e| {
|
250
|
+
Error::new(
|
251
|
+
exception::io_error(),
|
252
|
+
format!("set_nonblocking failed for HTTP: {}", e),
|
253
|
+
)
|
254
|
+
})?;
|
255
|
+
info!("(PID={}) Listening on https://{}", libc::getpid(), address);
|
256
|
+
https_listener = Some(listener);
|
257
|
+
}
|
258
|
+
|
259
|
+
Ok((http_listener, https_listener))
|
260
|
+
}
|
261
|
+
|
262
|
+
async fn await_fd(fd: RawFd) -> Result<(), JoinError> {
|
263
|
+
let async_fd_result = AsyncFd::new(fd);
|
264
|
+
|
265
|
+
if let Ok(async_fd) = async_fd_result {
|
266
|
+
return tokio::spawn(async move {
|
267
|
+
loop {
|
268
|
+
match async_fd.readable().await {
|
269
|
+
Ok(guard) => {
|
270
|
+
if guard.ready().is_read_closed() {
|
271
|
+
break;
|
272
|
+
} else {
|
273
|
+
let _ = TokioSleep(Duration::from_millis(1000)).await;
|
274
|
+
}
|
275
|
+
}
|
276
|
+
Err(e) => {
|
277
|
+
println!("Error waiting for readiness: {:?}", e);
|
278
|
+
break;
|
279
|
+
}
|
280
|
+
}
|
281
|
+
}
|
282
|
+
})
|
283
|
+
.await;
|
284
|
+
}
|
285
|
+
Ok(())
|
286
|
+
}
|
287
|
+
|
288
|
+
async unsafe fn accept_loop_async(serve_args: &mut AcceptLoopArgs) -> Result<(), Error> {
|
289
|
+
let tls_listener = serve_args.https_listener.as_ref().map(|listener| {
|
290
|
+
TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
|
291
|
+
.expect("Construct Tokio Listener from STD")
|
292
|
+
});
|
293
|
+
let non_tls_listener = serve_args.http_listener.as_ref().map(|listener| {
|
294
|
+
TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
|
295
|
+
.expect("Construct Tokio Listener from STD")
|
296
|
+
});
|
297
|
+
|
298
|
+
let (signal_tx, mut signal_rx) = channel(1);
|
299
|
+
let (shutdown_tx, mut shutdown_rx) = channel(1);
|
300
|
+
*CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);
|
301
|
+
let children = CHILDREN.read().unwrap().to_vec();
|
302
|
+
let mut child_cycle = children.iter().cycle();
|
303
|
+
|
304
|
+
tokio::spawn(async move {
|
305
|
+
signal_rx.recv().await.ok();
|
306
|
+
info!(
|
307
|
+
"(PID={}) Signal received. Initiating shutdown",
|
308
|
+
libc::getpid()
|
309
|
+
);
|
310
|
+
shutdown_tx.send(()).ok();
|
311
|
+
});
|
312
|
+
|
313
|
+
loop {
|
314
|
+
tokio::select! {
|
315
|
+
conn = async { if let Some(listener) = tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if tls_listener.is_some() => {
|
316
|
+
if let Some(Ok((stream, _peer_addr))) = conn {
|
317
|
+
let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
|
318
|
+
|
319
|
+
let iov = [IoSlice::new(b"1")];
|
320
|
+
let next_child = child_cycle.next().expect("Next child process");
|
321
|
+
sendmsg::<()>(
|
322
|
+
next_child.writer.as_raw_fd() as i32,
|
323
|
+
&iov,
|
324
|
+
&[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
|
325
|
+
MsgFlags::empty(),
|
326
|
+
None, // No address
|
327
|
+
).expect("Failed to send file descriptor");
|
328
|
+
|
329
|
+
if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
|
330
|
+
break;
|
331
|
+
}
|
332
|
+
}
|
333
|
+
},
|
334
|
+
conn = async { if let Some(listener) = non_tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if non_tls_listener.is_some() => {
|
335
|
+
if let Some(Ok((stream, _peer_addr))) = conn {
|
336
|
+
let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
|
337
|
+
let iov = [IoSlice::new(b"0")];
|
338
|
+
let next_child = child_cycle.next().expect("Next child process");
|
339
|
+
sendmsg::<()>(
|
340
|
+
next_child.writer.as_raw_fd() as i32,
|
341
|
+
&iov,
|
342
|
+
&[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
|
343
|
+
MsgFlags::empty(),
|
344
|
+
None, // No address
|
345
|
+
).expect("Failed to send file descriptor");
|
346
|
+
if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
|
347
|
+
break;
|
348
|
+
}
|
349
|
+
}
|
350
|
+
}
|
351
|
+
_ = shutdown_rx.recv() => {
|
352
|
+
info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
|
353
|
+
break;
|
354
|
+
}
|
355
|
+
}
|
356
|
+
}
|
357
|
+
|
358
|
+
info!(
|
359
|
+
"(PID={}) Waiting up to {:?} for in-flight requests to complete",
|
360
|
+
libc::getpid(),
|
361
|
+
serve_args.shutdown_timeout
|
362
|
+
);
|
363
|
+
|
364
|
+
debug!("(PID={}) Dropping listeners", libc::getpid());
|
365
|
+
drop(tls_listener);
|
366
|
+
drop(non_tls_listener);
|
367
|
+
initiate_shutdown(0, 0);
|
368
|
+
|
369
|
+
Ok(())
|
370
|
+
}
|
371
|
+
|
372
|
+
use std::io;
|
373
|
+
use tokio::{io::AsyncWriteExt, net::UnixStream as TokioUnixStream};
|
374
|
+
|
375
|
+
async fn bridge_with_logging(
|
376
|
+
mut from: (impl AsyncRead + Unpin),
|
377
|
+
mut to: (impl AsyncWrite + Unpin),
|
378
|
+
direction_label: &str,
|
379
|
+
) -> io::Result<u64> {
|
380
|
+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
381
|
+
let mut buf = [0u8; 4096];
|
382
|
+
let mut total_bytes = 0u64;
|
383
|
+
|
384
|
+
loop {
|
385
|
+
let n = from.read(&mut buf).await?;
|
386
|
+
if n == 0 {
|
387
|
+
// EOF => no more data
|
388
|
+
break;
|
389
|
+
}
|
390
|
+
println!(
|
391
|
+
"[bridge {}] Read {} bytes: {:?}",
|
392
|
+
direction_label,
|
393
|
+
n,
|
394
|
+
&buf[..n]
|
395
|
+
);
|
396
|
+
|
397
|
+
to.write_all(&buf[..n]).await?;
|
398
|
+
total_bytes += n as u64;
|
399
|
+
}
|
400
|
+
Ok(total_bytes)
|
401
|
+
}
|
402
|
+
|
403
|
+
async fn two_way_bridge(upgraded: Upgraded, local: TokioUnixStream) -> io::Result<()> {
|
404
|
+
// Wrap the `Upgraded` with `TokioIo` so it implements AsyncRead+AsyncWrite
|
405
|
+
let client_io = TokioIo::new(upgraded);
|
406
|
+
|
407
|
+
// Split each side
|
408
|
+
let (mut lr, mut lw) = tokio::io::split(local);
|
409
|
+
let (mut cr, mut cw) = tokio::io::split(client_io);
|
410
|
+
|
411
|
+
let to_ruby = tokio::spawn(async move {
|
412
|
+
if let Err(e) = tokio::io::copy(&mut cr, &mut lw).await {
|
413
|
+
eprintln!("Error copying upgraded->local: {:?}", e);
|
414
|
+
}
|
415
|
+
});
|
416
|
+
let from_ruby = tokio::spawn(async move {
|
417
|
+
if let Err(e) = tokio::io::copy(&mut lr, &mut cw).await {
|
418
|
+
eprintln!("Error copying upgraded->local: {:?}", e);
|
419
|
+
}
|
420
|
+
});
|
421
|
+
|
422
|
+
let _ = to_ruby.await;
|
423
|
+
let _ = from_ruby.await;
|
424
|
+
Ok(())
|
425
|
+
}
|
426
|
+
|
427
|
+
async unsafe fn handle_hyper_request(
|
428
|
+
req: Request<Incoming>,
|
429
|
+
peer_addr: SocketAddr,
|
430
|
+
sender: Sender<RackRequest>,
|
431
|
+
script_name: String,
|
432
|
+
) -> Result<Response<BoxBody<Bytes, Infallible>>, &'static str> {
|
433
|
+
let (parts, body) = req.into_parts();
|
434
|
+
|
435
|
+
let has_body = !body.is_end_stream();
|
436
|
+
|
437
|
+
// 1. (Optional) If there's a body, we'll expose it as an IO into our Ruby process.
|
438
|
+
let (local_fd, remote_fd) = if has_body {
|
439
|
+
let (local, remote) = TokioUnixStream::pair().map_err(|_| "socketpair failed")?;
|
440
|
+
let local_fd = local.as_raw_fd();
|
441
|
+
let (_lr, mut lw) = tokio::io::split(local);
|
442
|
+
let mut data_stream = body.into_data_stream();
|
443
|
+
|
444
|
+
// Stream all incoming data into our remote FD (Opened inside Ruby as an IO)
|
445
|
+
tokio::spawn(async move {
|
446
|
+
while let Some(Ok(chunk)) = data_stream.next().await {
|
447
|
+
info!("Got write chunk ! {:?}", chunk);
|
448
|
+
if lw.write_all(&chunk).await.is_err() {
|
449
|
+
break;
|
450
|
+
}
|
451
|
+
}
|
452
|
+
info!("Relinquishing ownership of local_fd");
|
453
|
+
_lr.unsplit(lw).into_std().unwrap().into_raw_fd()
|
454
|
+
});
|
455
|
+
|
456
|
+
let remote_fd = remote.into_std().unwrap().into_raw_fd();
|
457
|
+
(Some(local_fd), remote_fd)
|
458
|
+
} else {
|
459
|
+
(None, -1)
|
460
|
+
};
|
461
|
+
|
462
|
+
// 2) Build and send a RackRequest to our Rack app.
|
463
|
+
let (rack_req, resp_rx) =
|
464
|
+
map_hyper_request(parts.clone(), peer_addr, script_name, remote_fd).await;
|
465
|
+
sender.send(rack_req).expect("Failed to send to Ruby");
|
466
|
+
|
467
|
+
// 3) Then wait for our Rack app to return a RackResponse
|
468
|
+
let rack_response = resp_rx.await.map_err(|_| "Ruby channel closed")?;
|
469
|
+
|
470
|
+
// 4) We return a negative status if the response was hijacked. We'll reuse the body pipe
|
471
|
+
// if it already exists, otherwise, we can count on the hijacked response to have created a new one.
|
472
|
+
if rack_response.status < 0 {
|
473
|
+
let local_fd = match local_fd {
|
474
|
+
Some(ld) => ld,
|
475
|
+
None => (rack_response.status * -1) as i32,
|
476
|
+
};
|
477
|
+
|
478
|
+
return match handle_hijacked_response(local_fd, parts).await {
|
479
|
+
Ok(resp) => Ok(resp),
|
480
|
+
Err(_) => Err("Error handling hijacked response"),
|
481
|
+
};
|
482
|
+
}
|
483
|
+
|
484
|
+
// 5) Otherwise, normal HTTP response
|
485
|
+
Ok(rack_response.into_hyper_response())
|
486
|
+
}
|
487
|
+
|
488
|
+
async fn handle_connection(
|
489
|
+
input_stream: Pin<Box<tokio::net::TcpStream>>,
|
490
|
+
tls_acceptor: Option<TlsAcceptor>,
|
491
|
+
server: Arc<Builder<TokioExecutor>>,
|
492
|
+
peer_addr: SocketAddr,
|
493
|
+
sender: Sender<RackRequest>,
|
494
|
+
script_name: String,
|
495
|
+
) {
|
496
|
+
unsafe {
|
497
|
+
debug!("(PID={}) Handling connection", libc::getpid());
|
498
|
+
}
|
499
|
+
let stream: TokioIo<_> = match &tls_acceptor {
|
500
|
+
Some(acceptor) => match acceptor.accept(input_stream).await {
|
501
|
+
Ok(tls_stream) => TokioIo::new(Box::pin(tls_stream) as Pin<Box<dyn IoStream>>),
|
502
|
+
Err(err) => {
|
503
|
+
warn!("TLS handshake error: {}", err);
|
504
|
+
return;
|
505
|
+
}
|
506
|
+
},
|
507
|
+
None => TokioIo::new(Box::pin(input_stream) as Pin<Box<dyn IoStream>>),
|
508
|
+
};
|
509
|
+
|
510
|
+
unsafe {
|
511
|
+
debug!("(PID={}) serve connection with upgrades", libc::getpid());
|
512
|
+
}
|
513
|
+
|
514
|
+
if let Err(err) = server
|
515
|
+
.serve_connection_with_upgrades(
|
516
|
+
stream,
|
517
|
+
service_fn(move |hyper_request| unsafe {
|
518
|
+
handle_hyper_request(
|
519
|
+
hyper_request,
|
520
|
+
peer_addr,
|
521
|
+
sender.clone(),
|
522
|
+
script_name.clone(),
|
523
|
+
)
|
524
|
+
}),
|
525
|
+
)
|
526
|
+
.await
|
527
|
+
{
|
528
|
+
debug!("Connection aborted: {:?}", err);
|
529
|
+
}
|
530
|
+
|
531
|
+
unsafe { debug!("(PID={}) Connection closed", libc::getpid()) };
|
532
|
+
}
|
533
|
+
|
534
|
+
/// Worker-process accept loop.
///
/// Receives connection fds from the master over the worker's Unix socket
/// (`serve_args.reader`). Each message carries an `SCM_RIGHTS` fd plus a one-byte payload,
/// where `'1'` means TLS. For every fd it spawns a `handle_connection` task and queues the
/// task's `JoinHandle` so in-flight connections can be awaited on shutdown.
///
/// Exit paths:
/// * a zero-byte message (EOF from the master) returns immediately;
/// * a shutdown notification (forwarded from the signal channel installed in
///   `CTRL_C_SIGNAL_CHANNEL`) breaks the loop. The function then waits for the in-flight
///   task for at most `serve_args.shutdown_timeout` and calls `initiate_shutdown(0, 0)`.
///
/// # Panics
/// Panics if `serve_args.reader` is `None`, if the reader fd cannot be registered with the
/// reactor, or if a received fd cannot be converted into a Tokio stream.
async unsafe fn accept_ruby_connection_async(
    serve_args: &mut AcceptLoopArgs,
    sender: Sender<RackRequest>,
) {
    let reader = serve_args.reader.as_mut().unwrap();

    // Workers never accept directly; connections arrive from the master instead.
    drop(serve_args.http_listener.take());
    drop(serve_args.https_listener.take());

    // Payload buffer (only byte 0, the TLS flag, is inspected) and ancillary-data space
    // for received fds.
    let buf_raw = &mut [0u8; 1024];
    let mut buf = [IoSliceMut::new(buf_raw)];
    let server = Arc::new(Builder::new(TokioExecutor::new()));
    let mut cmsg_space = cmsg_space!(Vec<RawFd>);

    let (signal_tx, mut signal_rx) = channel(1);
    let (shutdown_tx, mut shutdown_rx) = channel(1);
    *CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);

    // Queue of spawned connection-task handles, drained by the task below.
    let mut current_requests: (
        UnboundedSender<JoinHandle<()>>,
        UnboundedReceiver<JoinHandle<()>>,
    ) = tokio::sync::mpsc::unbounded_channel();

    // Awaits connection handles one at a time until a signal arrives. It then forwards the
    // shutdown to the main loop and awaits whatever handles are still queued.
    // NOTE(review): while a handle is being awaited inside the select arm, `signal_rx` is
    // not polled, so a long-lived (e.g. keep-alive) connection can delay shutdown until it
    // closes.
    // NOTE(review): if `recv()` yields `None` (all senders dropped), the `continue` arm
    // re-polls immediately in a loop.
    let in_flight_requests_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                _ = signal_rx.recv() => {
                    info!("(PID={}) Signal received. Initiating shutdown", libc::getpid());
                    shutdown_tx.send(()).ok();
                    break
                }
                result = current_requests.1.recv() => {
                    match result {
                        Some(handle) => {
                            if let Err(e) = handle.await {
                                error!("Error handling request: {:?}", e);
                            }
                        },
                        _ => {
                            continue;
                        }
                    }
                }
            }
        }
        info!("(PID={}) Flushing final requests", libc::getpid());

        info!(
            "(PID={}) Awaiting in-flight request for clean shut down",
            libc::getpid()
        );
        while let Ok(handle) = current_requests.1.try_recv() {
            if let Err(e) = handle.await {
                error!("Error handling request: {:?}", e);
            }
        }
    });
    let child_socket = Arc::new(AsyncFd::new(reader.as_raw_fd()).unwrap());
    let tls_config = configure_tls(serve_args.cert_path.clone(), serve_args.key_path.clone());
    let tls_acceptor = tls_config.map(TlsAcceptor::from);

    info!("(PID={}) Starting main loop", libc::getpid());

    loop {
        tokio::select! {
            guard = child_socket.readable() => {
                match guard {
                    Ok(mut ready_guard) => {
                        // Drain every queued message without blocking; the loop ends on
                        // EAGAIN (or any other recvmsg error).
                        while let Ok(msg) = recvmsg::<()>(
                            reader.as_raw_fd(),
                            &mut buf,
                            Some(&mut cmsg_space),
                            MsgFlags::MSG_DONTWAIT,
                        ) {
                            // EOF from the master.
                            // NOTE(review): returning here skips the in-flight wait and
                            // `initiate_shutdown` at the bottom of the function.
                            if msg.bytes == 0 {
                                info!("(PID={}) Received EOF. Exiting", libc::getpid());
                                return;
                            }

                            // Only the first fd of the first control message is used.
                            // NOTE(review): any additional fds would leak.
                            if let Some(ControlMessageOwned::ScmRights(fds)) = msg.cmsgs().unwrap().next() {
                                let fd: i32 = *fds.first().expect("File descriptor");
                                // Payload byte written by the master: '1' = TLS, '0' = plain.
                                let tls: bool = buf[0][0] == '1' as u8;

                                // The received fd is a new descriptor owned by this
                                // process; the TcpStream takes ownership of it.
                                let stream = TokioTcpStream::from_std(unsafe { TcpStream::from_raw_fd(fd) })
                                    .expect("Tokio from std");


                                match stream.peer_addr(){
                                    Ok(peer_addr) => {
                                        let server_clone = server.clone();

                                        let script_name_clone = serve_args.script_name.clone();
                                        let tls_acceptor = if tls { tls_acceptor.clone() } else { None };
                                        let sender_clone = sender.clone();
                                        let handle = tokio::spawn(async move {
                                            handle_connection(
                                                Box::pin(stream),
                                                tls_acceptor,
                                                server_clone,
                                                peer_addr,
                                                sender_clone,
                                                script_name_clone
                                            )
                                            .await;
                                        });

                                        info!("(PID={}) Spawned connection handler", libc::getpid());
                                        if let Err(e) = current_requests.0.send(handle) {
                                            error!("Failed to enqueue request: {:?}", e);
                                        }
                                    }
                                    // `stream` is dropped here, closing the connection.
                                    Err(e) => {
                                        error!("Failed to get peer address: {:?}", e);
                                        continue;
                                    }
                                }

                            } else {
                                error!("No file descriptor received");
                            }
                        }

                        // Clear readiness state after processing
                        ready_guard.clear_ready();
                    }
                    Err(e) => {
                        error!("(PID={}) Failed to poll readability: {:?}", libc::getpid(), e);
                    }
                }
            }
            _ = shutdown_rx.recv() => {
                info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
                break;
            }
        }
    }

    // Give in-flight connections up to `shutdown_timeout` to finish.
    tokio::select! {
        _ = in_flight_requests_handle => {
            debug!("(PID={}) In-flight requests handled gracefully", libc::getpid());
        }
        _ = tokio::time::sleep(serve_args.shutdown_timeout) => {
            debug!("(PID={}) Forced shutdown timeout reached", libc::getpid());
        }
    }

    // NOTE(review): the listeners were already dropped at the top of this function; this
    // log line is stale.
    debug!("(PID={}) Dropping listeners", libc::getpid());
    initiate_shutdown(0, 0);
}
|
683
|
+
|
684
|
+
pub unsafe fn call_without_gvl<F, R>(f: F) -> R
|
685
|
+
where
|
686
|
+
F: FnOnce() -> R,
|
687
|
+
{
|
688
|
+
// This is the function that Ruby calls “in the background” with the GVL released.
|
689
|
+
extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
|
690
|
+
where
|
691
|
+
F: FnOnce() -> R,
|
692
|
+
{
|
693
|
+
// 1) Reconstruct the Box that holds our closure
|
694
|
+
let closure_ptr = arg as *mut Option<F>;
|
695
|
+
let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
|
696
|
+
|
697
|
+
// 2) Call the user’s closure
|
698
|
+
let result = closure();
|
699
|
+
|
700
|
+
// 3) Box up the result so we can return a pointer to it
|
701
|
+
let boxed_result = Box::new(result);
|
702
|
+
Box::into_raw(boxed_result) as *mut c_void
|
703
|
+
}
|
704
|
+
|
705
|
+
// Box up the closure so we have a stable pointer
|
706
|
+
let mut closure_opt = Some(f);
|
707
|
+
let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
|
708
|
+
|
709
|
+
// 4) Actually call `rb_thread_call_without_gvl`
|
710
|
+
let raw_result_ptr =
|
711
|
+
rb_thread_call_without_gvl(Some(trampoline::<F, R>), closure_ptr, None, null_mut());
|
712
|
+
|
713
|
+
// 5) Convert the returned pointer back into R
|
714
|
+
let result_box = Box::from_raw(raw_result_ptr as *mut R);
|
715
|
+
*result_box
|
716
|
+
}
|
717
|
+
|
718
|
+
pub unsafe fn call_with_gvl<F, R>(f: F) -> R
|
719
|
+
where
|
720
|
+
F: FnOnce() -> R,
|
721
|
+
{
|
722
|
+
// This is the function that Ruby calls “in the background” with the GVL released.
|
723
|
+
extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
|
724
|
+
where
|
725
|
+
F: FnOnce() -> R,
|
726
|
+
{
|
727
|
+
// 1) Reconstruct the Box that holds our closure
|
728
|
+
let closure_ptr = arg as *mut Option<F>;
|
729
|
+
let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
|
730
|
+
|
731
|
+
// 2) Call the user’s closure
|
732
|
+
let result = closure();
|
733
|
+
|
734
|
+
// 3) Box up the result so we can return a pointer to it
|
735
|
+
let boxed_result = Box::new(result);
|
736
|
+
Box::into_raw(boxed_result) as *mut c_void
|
737
|
+
}
|
738
|
+
|
739
|
+
// Box up the closure so we have a stable pointer
|
740
|
+
let mut closure_opt = Some(f);
|
741
|
+
let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
|
742
|
+
|
743
|
+
// 4) Actually call `rb_thread_call_without_gvl`
|
744
|
+
let raw_result_ptr = rb_thread_call_with_gvl(Some(trampoline::<F, R>), closure_ptr);
|
745
|
+
|
746
|
+
// 5) Convert the returned pointer back into R
|
747
|
+
let result_box = Box::from_raw(raw_result_ptr as *mut R);
|
748
|
+
*result_box
|
749
|
+
}
|
750
|
+
|
751
|
+
unsafe extern "C" fn accept_ruby_connection_loop(ptr: *mut c_void) -> *mut c_void {
|
752
|
+
let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
|
753
|
+
let mut builder: RuntimeBuilder = RuntimeBuilder::new_current_thread();
|
754
|
+
|
755
|
+
libc::signal(libc::SIGTERM, initiate_shutdown as usize);
|
756
|
+
libc::signal(libc::SIGINT, initiate_shutdown as usize);
|
757
|
+
|
758
|
+
let runtime = builder
|
759
|
+
.thread_name("osprey-accept-runtime")
|
760
|
+
.thread_stack_size(3 * 1024 * 1024)
|
761
|
+
.enable_io()
|
762
|
+
.enable_time()
|
763
|
+
.build()
|
764
|
+
.expect("Failed to build Tokio runtime");
|
765
|
+
|
766
|
+
let (sender, receiver) = crossbeam::channel::unbounded::<RackRequest>();
|
767
|
+
start_ruby_thread_pool(serve_args.threads as usize, Arc::new(receiver));
|
768
|
+
|
769
|
+
let _ = runtime.block_on(accept_ruby_connection_async(serve_args, sender));
|
770
|
+
|
771
|
+
std::ptr::null_mut()
|
772
|
+
}
|
773
|
+
|
774
|
+
unsafe extern "C" fn accept_loop(ptr: *mut c_void) -> *mut c_void {
|
775
|
+
let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
|
776
|
+
let mut builder: RuntimeBuilder = if serve_args.threads == 1 {
|
777
|
+
RuntimeBuilder::new_current_thread()
|
778
|
+
} else {
|
779
|
+
RuntimeBuilder::new_multi_thread()
|
780
|
+
};
|
781
|
+
let runtime = builder
|
782
|
+
.worker_threads(serve_args.threads as usize)
|
783
|
+
.thread_name("osprey-accept-runtime")
|
784
|
+
.thread_stack_size(3 * 1024 * 1024)
|
785
|
+
.enable_io()
|
786
|
+
.enable_time()
|
787
|
+
.build()
|
788
|
+
.expect("Failed to build Tokio runtime");
|
789
|
+
|
790
|
+
libc::signal(libc::SIGTERM, initiate_shutdown as usize);
|
791
|
+
libc::signal(libc::SIGINT, initiate_shutdown as usize);
|
792
|
+
|
793
|
+
let _ = runtime.block_on(accept_loop_async(serve_args));
|
794
|
+
runtime.shutdown_timeout(Duration::from_millis(2000));
|
795
|
+
|
796
|
+
info!(
|
797
|
+
"(PID={}) accept_loop finished, shutting down.",
|
798
|
+
libc::getpid()
|
799
|
+
);
|
800
|
+
|
801
|
+
libc::signal(libc::SIGTERM, libc::SIG_DFL);
|
802
|
+
libc::signal(libc::SIGINT, libc::SIG_DFL);
|
803
|
+
|
804
|
+
debug!(
|
805
|
+
"(PID={}) HTTP listener state: {:?}",
|
806
|
+
libc::getpid(),
|
807
|
+
serve_args.http_listener
|
808
|
+
);
|
809
|
+
debug!("(PID={}) Dropping STD listeners", libc::getpid());
|
810
|
+
drop(serve_args.http_listener.take());
|
811
|
+
drop(serve_args.https_listener.take());
|
812
|
+
|
813
|
+
std::ptr::null_mut()
|
814
|
+
}
|
815
|
+
|
816
|
+
unsafe extern "C" fn gather_requests_no_gvl(ptr: *mut c_void) -> *mut c_void {
|
817
|
+
let receiver = unsafe { &*(ptr as *const Arc<Receiver<RackRequest>>) };
|
818
|
+
|
819
|
+
let batch_size = 25;
|
820
|
+
let mut batch = Vec::with_capacity(batch_size);
|
821
|
+
|
822
|
+
// Try to collect a batch of up to `batch_size` requests
|
823
|
+
while batch.len() < batch_size {
|
824
|
+
match receiver.try_recv() {
|
825
|
+
Ok(request) => batch.push(request),
|
826
|
+
Err(_) => break,
|
827
|
+
}
|
828
|
+
}
|
829
|
+
|
830
|
+
// If the batch is empty, wait for a single request
|
831
|
+
if batch.is_empty() {
|
832
|
+
match receiver.recv_timeout(Duration::from_micros(100)) {
|
833
|
+
Ok(request) => batch.push(request),
|
834
|
+
Err(_) => (),
|
835
|
+
}
|
836
|
+
}
|
837
|
+
|
838
|
+
// Process the collected batch of requests
|
839
|
+
Box::into_raw(Box::new(batch)) as *mut c_void
|
840
|
+
}
|
841
|
+
|
842
|
+
unsafe fn gather_requests(receiver: &Arc<Receiver<RackRequest>>) -> Vec<RackRequest> {
|
843
|
+
let data_ptr = receiver as *const Arc<Receiver<RackRequest>> as *mut c_void;
|
844
|
+
let result_ptr =
|
845
|
+
rb_thread_call_without_gvl(Some(gather_requests_no_gvl), data_ptr, None, null_mut());
|
846
|
+
let result = Box::from_raw(result_ptr as *mut Vec<RackRequest>);
|
847
|
+
*result
|
848
|
+
}
|
849
|
+
|
850
|
+
/// Body of each Ruby worker thread created by `start_ruby_thread_pool`.
///
/// `ptr` is a `Box<Arc<Receiver<RackRequest>>>` leaked by the spawner. It is reclaimed here,
/// so this function owns its receiver clone.
///
/// `use_scheduler` is hard-coded to `true`, so the `else` branch (plain sequential processing)
/// is currently unreachable.
///
/// In scheduler mode it does the following:
/// 1. installs a new `SCHEDULER_CLASS` instance via `Fiber.set_scheduler`;
/// 2. schedules one fiber (a proc passed with a block to `*ID_SCHEDULE` on `Fiber`) that loops
///    until `SHUTDOWN_REQUESTED` is set. Each iteration gathers a batch of requests and starts
///    one fiber per request through `scheduler.<ID_RESUME>(fiber)`.
///
/// Returns 0, which is handed back to `rb_thread_create` as the thread function's result.
unsafe extern "C" fn ruby_thread_worker(ptr: *mut c_void) -> u64 {
    let ruby = Ruby::get_unchecked();
    // Take back ownership of the Arc boxed by `start_ruby_thread_pool`.
    let receiver = *Box::from_raw(ptr as *mut Arc<Receiver<RackRequest>>);
    let use_scheduler = true;

    if use_scheduler {
        let receiver_clone = receiver.clone();
        let fiber_class = ruby.get_inner(&FIBER_CLASS);
        let scheduler = ruby.get_inner(&SCHEDULER_CLASS).new_instance(()).unwrap();
        fiber_class
            .funcall::<_, _, Value>("set_scheduler", (scheduler,))
            .unwrap();
        // Dispatch loop; runs inside a scheduler-managed fiber.
        let fib = ruby.proc_from_fn(move |_rb, _arg, _blk| {
            let ruby = Ruby::get_unchecked();
            let fiber_class = ruby.get_inner(&FIBER_CLASS);
            let receiver = receiver_clone.clone();
            let scheduler: Value = fiber_class.funcall("scheduler", ()).unwrap();
            let id_yield = ruby.intern("yield");
            // Number of request fibers that have started but not yet finished.
            let atomic_counter = Arc::new(AtomicUsize::new(0));
            loop {
                let current_count = atomic_counter.load(Ordering::Relaxed);
                // While request fibers are in flight, let the scheduler run them before
                // gathering more work. A failing `yield` ends the dispatch loop.
                if current_count > 0 {
                    if let Err(_) = scheduler.funcall::<Id, (), Value>(id_yield, ()) {
                        break;
                    }
                }
                // May block briefly without the GVL while waiting for requests.
                let batch = gather_requests(&receiver);
                for request in batch {
                    let atomic_counter_clone = Arc::clone(&atomic_counter);
                    let fiber = ruby
                        .fiber_new_from_fn(Default::default(), move |ruby, _args, _block| {
                            atomic_counter_clone.fetch_add(1, Ordering::Relaxed);
                            // NOTE(review): if `process` returns Err this panics, and the
                            // decrement below never runs, leaving the counter permanently > 0.
                            request.process(ruby).expect("Error processing request");
                            atomic_counter_clone.fetch_sub(1, Ordering::Relaxed);
                        })
                        .unwrap();

                    // Errors from resuming the request fiber are ignored.
                    scheduler.funcall::<_, _, Value>(*ID_RESUME, (fiber,)).ok();
                }

                if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
                    break;
                }
            }
        });
        fiber_class
            .funcall_with_block::<_, _, Value>(*ID_SCHEDULE, (), fib)
            .expect("scheduled fiber");
    } else {
        // Unreachable while `use_scheduler` is hard-coded to `true`: process requests
        // sequentially on this thread.
        loop {
            let batch = gather_requests(&receiver);
            for request in batch {
                request.process(&ruby).expect("Error processing request");
            }

            if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
                break;
            }
        }
    }
    0
}
|
912
|
+
|
913
|
+
fn start_ruby_thread_pool(pool_size: usize, receiver: Arc<Receiver<RackRequest>>) {
|
914
|
+
for _i in 0..pool_size {
|
915
|
+
unsafe {
|
916
|
+
rb_thread_create(
|
917
|
+
Some(ruby_thread_worker),
|
918
|
+
Box::into_raw(Box::new(receiver.clone())) as *mut c_void,
|
919
|
+
)
|
920
|
+
};
|
921
|
+
}
|
922
|
+
}
|
923
|
+
|
924
|
+
unsafe fn initiate_shutdown(signum: i32, action: sighandler_t) {
|
925
|
+
if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
|
926
|
+
return;
|
927
|
+
}
|
928
|
+
SHUTDOWN_REQUESTED.store(true, Ordering::SeqCst);
|
929
|
+
info!(
|
930
|
+
"(PID={}) Initiating shutdown with signal {}: {}",
|
931
|
+
libc::getpid(),
|
932
|
+
signum,
|
933
|
+
action
|
934
|
+
);
|
935
|
+
|
936
|
+
if let Some(sender) = CTRL_C_SIGNAL_CHANNEL.write().unwrap().take() {
|
937
|
+
let _ = sender.send(());
|
938
|
+
}
|
939
|
+
|
940
|
+
let children = CHILDREN.read().unwrap();
|
941
|
+
for child in children.iter() {
|
942
|
+
if child.pid > 0 {
|
943
|
+
let sig_result = libc::kill(child.pid, libc::SIGTERM);
|
944
|
+
if sig_result == -1 {
|
945
|
+
eprint!("Failed to send SIGTERM to child {}", child.pid);
|
946
|
+
}
|
947
|
+
let mut status = 0;
|
948
|
+
let wait_result = libc::waitpid(child.pid, &mut status, libc::WNOHANG);
|
949
|
+
if wait_result != 0 {
|
950
|
+
libc::kill(child.pid, libc::SIGKILL);
|
951
|
+
libc::waitpid(child.pid, &mut status, 0);
|
952
|
+
}
|
953
|
+
}
|
954
|
+
}
|
955
|
+
}
|
956
|
+
|
957
|
+
/// Arguments for the pre-fork helpers (`prefork_setup`, `before_fork`, `after_fork`).
///
/// Handed across the GVL boundary as a raw `*mut c_void`.
struct PreforkArgs {
    /// Number of worker processes `prefork_setup` forks.
    workers: u16,
    /// Optional Ruby hook invoked once in the original process before any fork.
    before_fork: Option<Proc>,
    /// Optional Ruby hook invoked in each newly forked child.
    after_fork: Option<Proc>,
}
|
962
|
+
|
963
|
+
pub unsafe extern "C" fn kill_other_threads(
|
964
|
+
_: *mut std::os::raw::c_void,
|
965
|
+
) -> *mut std::os::raw::c_void {
|
966
|
+
let thread_class =
|
967
|
+
RClass::from_value(eval("Thread").expect("Failed to fetch Ruby Thread class"))
|
968
|
+
.expect("Failed to fetch Ruby Thread class");
|
969
|
+
let threads: RArray = thread_class
|
970
|
+
.funcall("list", ())
|
971
|
+
.expect("Failed to list Ruby threads");
|
972
|
+
let current: Value = thread_class
|
973
|
+
.funcall("current", ())
|
974
|
+
.expect("Failed to get current thread");
|
975
|
+
|
976
|
+
// Signal thread exit
|
977
|
+
for thr in threads {
|
978
|
+
let same: bool = thr
|
979
|
+
.funcall("==", (current,))
|
980
|
+
.expect("Failed to compare threads");
|
981
|
+
if !same {
|
982
|
+
let _: Option<Value> = thr.funcall("exit", ()).expect("Failed to exit thread");
|
983
|
+
}
|
984
|
+
}
|
985
|
+
|
986
|
+
// Attempt to join
|
987
|
+
for thr in threads {
|
988
|
+
let same: bool = thr
|
989
|
+
.funcall("==", (current,))
|
990
|
+
.expect("Failed to compare threads");
|
991
|
+
if !same {
|
992
|
+
let _: Option<Value> = thr
|
993
|
+
.funcall("join", (0.5_f64,))
|
994
|
+
.expect("Failed to join thread");
|
995
|
+
}
|
996
|
+
}
|
997
|
+
|
998
|
+
// Terminate
|
999
|
+
for thr in threads {
|
1000
|
+
let same: bool = thr
|
1001
|
+
.funcall("==", (current,))
|
1002
|
+
.expect("Failed to compare threads");
|
1003
|
+
if !same {
|
1004
|
+
let alive: bool = thr
|
1005
|
+
.funcall("alive?", ())
|
1006
|
+
.expect("Failed to check if thread is alive");
|
1007
|
+
if alive {
|
1008
|
+
let _: Option<Value> = thr.funcall("kill", ()).expect("Failed to kill thread");
|
1009
|
+
}
|
1010
|
+
}
|
1011
|
+
}
|
1012
|
+
|
1013
|
+
std::ptr::null_mut()
|
1014
|
+
}
|
1015
|
+
|
1016
|
+
pub unsafe extern "C" fn before_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
|
1017
|
+
let pf_args = &mut *(ptr as *mut PreforkArgs);
|
1018
|
+
debug!("(PID={}) Before fork", libc::getpid());
|
1019
|
+
if let Some(ref p) = pf_args.before_fork {
|
1020
|
+
let _ = p.call::<(), Value>(());
|
1021
|
+
};
|
1022
|
+
std::ptr::null_mut()
|
1023
|
+
}
|
1024
|
+
|
1025
|
+
pub unsafe extern "C" fn after_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
|
1026
|
+
let pf_args = &mut *(ptr as *mut PreforkArgs);
|
1027
|
+
debug!("(PID={}) After fork", libc::getpid());
|
1028
|
+
if let Some(ref p) = pf_args.after_fork {
|
1029
|
+
let _ = p.call::<(), Value>(());
|
1030
|
+
};
|
1031
|
+
std::ptr::null_mut()
|
1032
|
+
}
|
1033
|
+
|
1034
|
+
pub unsafe extern "C" fn rb_fork(_: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
|
1035
|
+
let ruby = Ruby::get_unchecked();
|
1036
|
+
let pid: Option<i32> = ruby
|
1037
|
+
.module_kernel()
|
1038
|
+
.funcall(ruby.intern("fork"), ())
|
1039
|
+
.unwrap();
|
1040
|
+
if let Some(child_pid) = pid {
|
1041
|
+
debug!(
|
1042
|
+
"(PID={}) Forked child process with PID {}",
|
1043
|
+
libc::getpid(),
|
1044
|
+
child_pid
|
1045
|
+
);
|
1046
|
+
return child_pid as *mut std::os::raw::c_void;
|
1047
|
+
} else {
|
1048
|
+
debug!("(PID={}) New child process starting ", libc::getpid());
|
1049
|
+
std::ptr::null_mut()
|
1050
|
+
}
|
1051
|
+
}
|
1052
|
+
|
1053
|
+
/// Forks `PreforkArgs::workers` worker processes; runs without the GVL (called through
/// `rb_thread_call_without_gvl` from `start_server`).
///
/// Before forking it takes the GVL to run the `before_fork` hook and `kill_other_threads`
/// (presumably because fork(2) duplicates only the calling thread).
///
/// Each iteration creates a Unix socketpair and forks:
/// * in the child, it runs `after_fork`, `dup2`s `/dev/null` onto stdin, and returns the
///   leaked `read` end as a `Box<OwnedFd>` pointer, so the child never loops again;
/// * in the parent, it records a `ChildProcess` with the `write` end.
///
/// The parent stores all children in `CHILDREN` and returns null. With `workers == 0`
/// nothing is forked and null is returned.
unsafe extern "C" fn prefork_setup(ptr: *mut c_void) -> *mut c_void {
    let pf_args = &*(ptr as *mut PreforkArgs);
    let workers = pf_args.workers;
    if workers >= 1 {
        // NOTE(review): capacity is `workers - 1`, but the parent pushes `workers` children.
        let mut children = Vec::with_capacity(workers as usize - 1);
        rb_thread_call_with_gvl(Some(before_fork), ptr);
        rb_thread_call_with_gvl(Some(kill_other_threads), std::ptr::null_mut());

        for _i in 0..workers {
            let (read, write) = socketpair(
                AddressFamily::Unix,
                SockType::Stream,
                None,
                SockFlag::empty(),
            )
            .unwrap();

            // `rb_fork` packs the child's PID into the pointer (null in the child).
            let response_ptr = rb_thread_call_with_gvl(Some(rb_fork), std::ptr::null_mut());
            let pid = response_ptr as i32;
            debug!("Got reader and writer");

            match pid {
                0 => {
                    // Child: leak the parent's end instead of dropping it, so its fd is not
                    // closed here (presumably intentional — confirm).
                    let _ = Box::into_raw(Box::new(write));

                    rb_thread_call_with_gvl(Some(after_fork), ptr);
                    // NOTE(review): /dev/null is opened write-only and installed as stdin,
                    // so reads from stdin in the worker will fail rather than return EOF.
                    let devnull = OpenOptions::new()
                        .write(true)
                        .open("/dev/null")
                        .expect("Failed to open /dev/null");
                    unsafe {
                        libc::dup2(devnull.as_raw_fd(), libc::STDIN_FILENO);
                    }
                    // Hand the read end back to `start_server` (reclaimed with Box::from_raw).
                    return Box::into_raw(Box::new(read)) as *mut c_void;
                }
                child_pid => {
                    // Parent: leak the child's end so its fd is not closed here.
                    let _ = Box::into_raw(Box::new(read));
                    children.push(ChildProcess {
                        pid: child_pid,
                        writer: Arc::new(write),
                    })
                }
            }
        }
        CHILDREN.write().unwrap().extend(children);
    }
    return std::ptr::null_mut();
}
|
1101
|
+
|
1102
|
+
// TLS configuration
|
1103
|
+
fn configure_tls(cert_path: Option<String>, key_path: Option<String>) -> Option<Arc<ServerConfig>> {
|
1104
|
+
if cert_path.is_none() || key_path.is_none() {
|
1105
|
+
return None;
|
1106
|
+
}
|
1107
|
+
use tokio_rustls::rustls::{Certificate, PrivateKey};
|
1108
|
+
|
1109
|
+
let cert_bytes =
|
1110
|
+
std::fs::read(cert_path.expect("Expected a cert_path")).expect("Failed to read cert file");
|
1111
|
+
let key_bytes =
|
1112
|
+
std::fs::read(key_path.expect("Expected a key_path")).expect("Failed to read key file");
|
1113
|
+
|
1114
|
+
let certs = vec![Certificate(cert_bytes)];
|
1115
|
+
let key = PrivateKey(key_bytes);
|
1116
|
+
|
1117
|
+
let mut config = ServerConfig::builder()
|
1118
|
+
.with_safe_defaults()
|
1119
|
+
.with_no_client_auth()
|
1120
|
+
.with_single_cert(certs, key)
|
1121
|
+
.expect("Failed to build TLS config");
|
1122
|
+
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
1123
|
+
|
1124
|
+
Some(Arc::new(config))
|
1125
|
+
}
|
1126
|
+
|
1127
|
+
fn filter_keywords(keywords: RHash, allowed: &[&str]) -> RHash {
|
1128
|
+
let filtered = RHash::new();
|
1129
|
+
for key in keywords.funcall::<_, _, RArray>("keys", ()).unwrap() {
|
1130
|
+
let key_str: String = key.funcall("to_s", ()).unwrap();
|
1131
|
+
if allowed.contains(&key_str.as_str()) {
|
1132
|
+
filtered.aset(key, keywords.get(key).unwrap()).unwrap();
|
1133
|
+
}
|
1134
|
+
}
|
1135
|
+
filtered
|
1136
|
+
}
|
1137
|
+
/// The main "start(...)" function that Ruby calls via `Osprey.start`
///
/// Flow:
/// 1. Clears `SHUTDOWN_REQUESTED` and parses the keyword arguments.
/// 2. Runs `prefork_setup` without the GVL. A forked child gets back its socketpair read end
///    (`reader`); the original process gets null.
/// 3. If any children were recorded (this is the parent), it binds the listeners, runs
///    `accept_loop` without the GVL, then waits for the workers and clears `CHILDREN`.
///    Otherwise (a worker, or no workers forked) it runs `accept_ruby_connection_loop`.
///
/// # Errors
/// Propagates argument-parsing errors and listener setup errors.
fn start_server(args: &[Value]) -> Result<(), Error> {
    SHUTDOWN_REQUESTED.store(false, Ordering::SeqCst);

    let srv = RackServerConfig::from_args(args)?;

    let mut http_listener: Option<TcpListener> = None;
    let mut https_listener: Option<TcpListener> = None;
    let mut reader: Option<OwnedFd> = None;

    let mut pf_args = PreforkArgs {
        workers: srv.workers,
        before_fork: srv.before_fork,
        after_fork: srv.after_fork,
    };

    let pf_args_ptr = &mut pf_args as *mut PreforkArgs as *mut c_void;

    unsafe {
        info!("(PID={}) Parent process started.", libc::getpid());
        let ptr = rb_thread_call_without_gvl(
            Some(prefork_setup),
            pf_args_ptr,
            None,
            std::ptr::null_mut(),
        );
        // Non-null only in a forked child: the boxed read end of its socketpair.
        if !ptr.is_null() {
            reader = Some(*Box::from_raw(ptr as *mut OwnedFd));
        }
    }

    unsafe {
        let is_parent = !CHILDREN.read().unwrap().is_empty(); //|| srv.workers == 1;

        // Only the parent binds the TCP listeners.
        if is_parent {
            (http_listener, https_listener) = setup_listeners(&srv)?;
        }
        let serve_args = AcceptLoopArgs {
            http_listener,
            https_listener,
            reader,
            cert_path: srv.cert_path,
            key_path: srv.key_path,
            threads: srv.threads,
            shutdown_timeout: srv.shutdown_timeout,
            script_name: srv.script_name,
        };
        // NOTE(review): this box is never reclaimed with `Box::from_raw` in this function.
        let args_ptr = Box::into_raw(Box::new(serve_args)) as *mut c_void;
        if is_parent {
            // NOTE(review): purpose of this no-op GVL release is unclear from this file.
            call_without_gvl(|| 3 + 2);
            rb_thread_call_without_gvl(
                Some(accept_loop),
                args_ptr,
                Some(accept_loop_interrupt),
                std::ptr::null_mut(),
            );
            wait_for_children(srv.shutdown_timeout);
            CHILDREN.write().unwrap().clear();
        } else {
            rb_thread_call_without_gvl(
                Some(accept_ruby_connection_loop),
                args_ptr,
                None,
                std::ptr::null_mut(),
            );
        }
    }

    debug!("(PID={}) accept_loop returned. Exiting.", unsafe {
        libc::getpid()
    });
    Ok(())
}
|
1210
|
+
|
1211
|
+
unsafe fn is_pid_running(pid: i32) -> bool {
|
1212
|
+
let mut status = 0;
|
1213
|
+
let result = unsafe { libc::waitpid(pid, &mut status, libc::WNOHANG) };
|
1214
|
+
return result == 0;
|
1215
|
+
}
|
1216
|
+
|
1217
|
+
fn wait_for_children(shutdown_timeout: Duration) {
|
1218
|
+
unsafe {
|
1219
|
+
libc::signal(libc::SIGTERM, initiate_shutdown as usize);
|
1220
|
+
libc::signal(libc::SIGINT, initiate_shutdown as usize);
|
1221
|
+
}
|
1222
|
+
|
1223
|
+
for child_process in CHILDREN.read().unwrap().iter() {
|
1224
|
+
let mut status = 0;
|
1225
|
+
|
1226
|
+
loop {
|
1227
|
+
if !unsafe { is_pid_running(child_process.pid) } {
|
1228
|
+
break;
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
let wait_result =
|
1232
|
+
unsafe { libc::waitpid(child_process.pid, &mut status, libc::WNOHANG) };
|
1233
|
+
|
1234
|
+
if wait_result == -1 {
|
1235
|
+
// Error occurred, handle appropriately
|
1236
|
+
|
1237
|
+
debug!(
|
1238
|
+
"(PID={}) Error waiting for child process: {}",
|
1239
|
+
child_process.pid,
|
1240
|
+
std::io::Error::last_os_error()
|
1241
|
+
);
|
1242
|
+
break;
|
1243
|
+
} else if wait_result > 0 {
|
1244
|
+
if libc::WIFSIGNALED(status) || libc::WIFSTOPPED(status) || libc::WIFEXITED(status)
|
1245
|
+
{
|
1246
|
+
break;
|
1247
|
+
}
|
1248
|
+
} else {
|
1249
|
+
unsafe {
|
1250
|
+
if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
|
1251
|
+
debug!(
|
1252
|
+
"(PID={}) Waiting for child {} to exit. Status: {}. Wait Result: {}",
|
1253
|
+
libc::getpid(),
|
1254
|
+
child_process.pid,
|
1255
|
+
status,
|
1256
|
+
wait_result
|
1257
|
+
);
|
1258
|
+
sleep(shutdown_timeout.as_secs() as u32 + 1);
|
1259
|
+
}
|
1260
|
+
};
|
1261
|
+
}
|
1262
|
+
}
|
1263
|
+
}
|
1264
|
+
}
|
1265
|
+
|
1266
|
+
impl RackServerConfig {
|
1267
|
+
fn from_args(args: &[Value]) -> Result<Self, Error> {
|
1268
|
+
let scan_args: Args<(), (), (), (), RHash, ()> = scan_args(args)?;
|
1269
|
+
|
1270
|
+
let args: KwArgs<
|
1271
|
+
(Value,),
|
1272
|
+
(
|
1273
|
+
Option<String>,
|
1274
|
+
Option<u16>,
|
1275
|
+
Option<u16>,
|
1276
|
+
Option<u16>,
|
1277
|
+
Option<u16>,
|
1278
|
+
Option<String>,
|
1279
|
+
Option<String>,
|
1280
|
+
Option<f64>,
|
1281
|
+
Option<String>,
|
1282
|
+
),
|
1283
|
+
(),
|
1284
|
+
> = get_kwargs(
|
1285
|
+
filter_keywords(
|
1286
|
+
scan_args.keywords,
|
1287
|
+
&[
|
1288
|
+
"app",
|
1289
|
+
"host",
|
1290
|
+
"port",
|
1291
|
+
"http_port",
|
1292
|
+
"workers",
|
1293
|
+
"threads",
|
1294
|
+
"cert_path",
|
1295
|
+
"key_path",
|
1296
|
+
"shutdown_timeout",
|
1297
|
+
"script_name",
|
1298
|
+
],
|
1299
|
+
),
|
1300
|
+
&["app"],
|
1301
|
+
&[
|
1302
|
+
"host",
|
1303
|
+
"port",
|
1304
|
+
"http_port",
|
1305
|
+
"workers",
|
1306
|
+
"threads",
|
1307
|
+
"cert_path",
|
1308
|
+
"key_path",
|
1309
|
+
"shutdown_timeout",
|
1310
|
+
"script_name",
|
1311
|
+
],
|
1312
|
+
)?;
|
1313
|
+
|
1314
|
+
let args2: KwArgs<(), (Option<Proc>, Option<Proc>), ()> = get_kwargs(
|
1315
|
+
filter_keywords(scan_args.keywords, &["before_fork", "after_fork"]),
|
1316
|
+
&[],
|
1317
|
+
&["before_fork", "after_fork"],
|
1318
|
+
)?;
|
1319
|
+
|
1320
|
+
let ruby = Ruby::get().unwrap();
|
1321
|
+
let host = args.optional.0.unwrap_or("127.0.0.1".to_owned());
|
1322
|
+
let port = args.optional.1.unwrap_or(3000);
|
1323
|
+
let http_port = args.optional.2;
|
1324
|
+
let workers = args.optional.3.unwrap_or(1);
|
1325
|
+
let threads = args.optional.4.unwrap_or(1);
|
1326
|
+
let cert_path = args.optional.5;
|
1327
|
+
let key_path = args.optional.6;
|
1328
|
+
let shutdown_timeout =
|
1329
|
+
Duration::from_millis((args.optional.7.unwrap_or(5.0) * 1000.0) as u64);
|
1330
|
+
let script_name = args.optional.8.unwrap_or("/".to_owned());
|
1331
|
+
|
1332
|
+
let app = args.required.0;
|
1333
|
+
let handler: Value = app.funcall("method", (*ID_CALL,)).unwrap();
|
1334
|
+
ruby.get_inner(&OSPREY)
|
1335
|
+
.ivar_set("@handler", handler)
|
1336
|
+
.unwrap();
|
1337
|
+
ruby.get_inner(&WORK_QUEUE)
|
1338
|
+
.funcall::<&str, (), Value>("clear", ())
|
1339
|
+
.ok();
|
1340
|
+
Ok(Self {
|
1341
|
+
host,
|
1342
|
+
port,
|
1343
|
+
http_port,
|
1344
|
+
workers,
|
1345
|
+
threads,
|
1346
|
+
cert_path,
|
1347
|
+
key_path,
|
1348
|
+
shutdown_timeout,
|
1349
|
+
script_name,
|
1350
|
+
before_fork: args2.optional.0,
|
1351
|
+
after_fork: args2.optional.1,
|
1352
|
+
})
|
1353
|
+
}
|
1354
|
+
}
|
1355
|
+
/// Standard Magnus init
|
1356
|
+
#[magnus::init]
|
1357
|
+
fn init(ruby: &magnus::Ruby) -> Result<(), Error> {
|
1358
|
+
SimpleLogger::new().env().init().unwrap();
|
1359
|
+
|
1360
|
+
ruby.get_inner(&OSPREY)
|
1361
|
+
.define_singleton_method("start", function!(start_server, -1))?;
|
1362
|
+
|
1363
|
+
define_collector(ruby)?;
|
1364
|
+
define_request_info(&ruby)?;
|
1365
|
+
|
1366
|
+
Ok(())
|
1367
|
+
}
|
1368
|
+
|
1369
|
+
// Mapping functions for requests and responses
|
1370
|
+
#[inline]
|
1371
|
+
async fn map_hyper_request(
|
1372
|
+
parts: hyper::http::request::Parts,
|
1373
|
+
peer_addr: SocketAddr,
|
1374
|
+
script_name: String,
|
1375
|
+
stream_fd: i32,
|
1376
|
+
) -> (RackRequest, oneshot::Receiver<RackResponse>) {
|
1377
|
+
let (sender, receiver) = oneshot::channel();
|
1378
|
+
|
1379
|
+
let headers: HashMap<String, String> = parts
|
1380
|
+
.headers
|
1381
|
+
.iter()
|
1382
|
+
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
1383
|
+
.collect();
|
1384
|
+
|
1385
|
+
let protocol = headers
|
1386
|
+
.get("upgrade")
|
1387
|
+
.and_then(|value| {
|
1388
|
+
Some(
|
1389
|
+
value
|
1390
|
+
.split(',')
|
1391
|
+
.map(|s| s.trim().to_string())
|
1392
|
+
.collect::<Vec<_>>(),
|
1393
|
+
)
|
1394
|
+
})
|
1395
|
+
.or_else(|| {
|
1396
|
+
headers.get("protocol").map(|value| {
|
1397
|
+
value
|
1398
|
+
.split(',')
|
1399
|
+
.map(|s| s.trim().to_string())
|
1400
|
+
.collect::<Vec<_>>()
|
1401
|
+
})
|
1402
|
+
})
|
1403
|
+
.unwrap_or_else(|| vec!["http".to_string()]);
|
1404
|
+
|
1405
|
+
unsafe {
|
1406
|
+
debug!("(PID={}) peer_adr: {:?}", libc::getpid(), peer_addr);
|
1407
|
+
debug!(
|
1408
|
+
"(PID={}) host header: {:?}",
|
1409
|
+
libc::getpid(),
|
1410
|
+
headers.get("host").unwrap_or(&"".to_string())
|
1411
|
+
);
|
1412
|
+
};
|
1413
|
+
let scheme = parts.uri.scheme_str().unwrap_or("http").to_string();
|
1414
|
+
let host = parts
|
1415
|
+
.uri
|
1416
|
+
.host()
|
1417
|
+
.unwrap_or(headers.get("host").unwrap_or(&"".to_string()))
|
1418
|
+
.to_string();
|
1419
|
+
let port = host
|
1420
|
+
.split(':')
|
1421
|
+
.nth(1)
|
1422
|
+
.unwrap_or(if &scheme == "https" { "443" } else { "80" })
|
1423
|
+
.parse::<i32>()
|
1424
|
+
.unwrap();
|
1425
|
+
|
1426
|
+
let query_string = parts.uri.query().unwrap_or("").to_string();
|
1427
|
+
(
|
1428
|
+
RackRequest(RefCell::new(Some(RackRequestInner {
|
1429
|
+
path: parts.uri.path().to_string(),
|
1430
|
+
method: parts.method.to_string(),
|
1431
|
+
version: format!("{:?}", parts.version),
|
1432
|
+
remote_addr: peer_addr.ip().to_string(),
|
1433
|
+
script_name,
|
1434
|
+
scheme,
|
1435
|
+
host,
|
1436
|
+
port,
|
1437
|
+
protocol,
|
1438
|
+
query_string,
|
1439
|
+
headers,
|
1440
|
+
stream_fd,
|
1441
|
+
sender: Arc::new(Mutex::new(Some(sender))),
|
1442
|
+
}))),
|
1443
|
+
receiver,
|
1444
|
+
)
|
1445
|
+
}
|
1446
|
+
|
1447
|
+
use hyper::body::{Frame, SizeHint};
|
1448
|
+
use std::task::{Context, Poll};
|
1449
|
+
use tokio::{
|
1450
|
+
io::{AsyncBufReadExt, AsyncReadExt, BufReader},
|
1451
|
+
sync::mpsc,
|
1452
|
+
};
|
1453
|
+
|
1454
|
+
/// Response body fed by a background task that reads from a hijacked socket.
///
/// Each `Ok` chunk received on `rx` is emitted as one data frame.
#[derive(Debug)]
pub struct HijackedStreamingBody {
    /// Chunks, or a read error, forwarded by the reader task.
    rx: mpsc::Receiver<Result<Bytes, io::Error>>,
    /// Once true, `poll_frame` reports end-of-stream without polling `rx`.
    done: bool,
}
|
1459
|
+
|
1460
|
+
impl hyper::body::Body for HijackedStreamingBody {
|
1461
|
+
type Data = Bytes;
|
1462
|
+
type Error = Infallible;
|
1463
|
+
|
1464
|
+
fn poll_frame(
|
1465
|
+
mut self: Pin<&mut Self>,
|
1466
|
+
cx: &mut Context<'_>,
|
1467
|
+
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
|
1468
|
+
if self.done {
|
1469
|
+
return Poll::Ready(None);
|
1470
|
+
}
|
1471
|
+
|
1472
|
+
match Pin::new(&mut self.rx).poll_recv(cx) {
|
1473
|
+
Poll::Ready(Some(Ok(bytes))) => {
|
1474
|
+
let frame = Frame::data(bytes);
|
1475
|
+
Poll::Ready(Some(Ok(frame)))
|
1476
|
+
}
|
1477
|
+
Poll::Ready(Some(Err(e))) => {
|
1478
|
+
self.done = true;
|
1479
|
+
Poll::Ready(None)
|
1480
|
+
}
|
1481
|
+
Poll::Ready(None) => Poll::Ready(None),
|
1482
|
+
Poll::Pending => Poll::Pending,
|
1483
|
+
}
|
1484
|
+
}
|
1485
|
+
fn is_end_stream(&self) -> bool {
|
1486
|
+
false // rely on the channel closing
|
1487
|
+
}
|
1488
|
+
|
1489
|
+
fn size_hint(&self) -> SizeHint {
|
1490
|
+
SizeHint::default()
|
1491
|
+
}
|
1492
|
+
}
|
1493
|
+
|
1494
|
+
// ─── A HELPER TO PARSE THE HTTP/1.1 RESPONSE ───────────────────────────────
|
1495
|
+
|
1496
|
+
// Given a buffered reader reading from the hijacked connection, read until the header section
|
1497
|
+
// is complete (i.e. until "\r\n\r\n" is encountered), and then return the header lines and any
|
1498
|
+
// bytes that already belong to the body.
|
1499
|
+
async fn read_response_headers<R>(reader: &mut R) -> io::Result<(Vec<String>, Vec<u8>)>
|
1500
|
+
where
|
1501
|
+
R: AsyncBufReadExt + Unpin,
|
1502
|
+
{
|
1503
|
+
let mut headers_buf = Vec::new();
|
1504
|
+
// Read until we have the header terminator.
|
1505
|
+
loop {
|
1506
|
+
let mut line = String::new();
|
1507
|
+
let n = reader.read_line(&mut line).await?;
|
1508
|
+
if n == 0 {
|
1509
|
+
// EOF reached unexpectedly.
|
1510
|
+
break;
|
1511
|
+
}
|
1512
|
+
headers_buf.extend_from_slice(line.as_bytes());
|
1513
|
+
if headers_buf.ends_with(b"\r\n\r\n") || headers_buf.ends_with(b"\n\n") {
|
1514
|
+
break;
|
1515
|
+
}
|
1516
|
+
}
|
1517
|
+
// Split the headers from any already-read part of the body.
|
1518
|
+
let header_body_split = b"\r\n\r\n";
|
1519
|
+
if let Some(pos) = headers_buf
|
1520
|
+
.windows(header_body_split.len())
|
1521
|
+
.position(|window| window == header_body_split)
|
1522
|
+
{
|
1523
|
+
let header_lines = &headers_buf[..pos];
|
1524
|
+
let remaining = headers_buf[(pos + header_body_split.len())..].to_vec();
|
1525
|
+
// Split the header lines into strings.
|
1526
|
+
let headers = String::from_utf8_lossy(header_lines)
|
1527
|
+
.lines()
|
1528
|
+
.map(|s| s.to_string())
|
1529
|
+
.collect::<Vec<_>>();
|
1530
|
+
Ok((headers, remaining))
|
1531
|
+
} else {
|
1532
|
+
// If not found, return all as headers and no body data.
|
1533
|
+
let headers = String::from_utf8_lossy(&headers_buf)
|
1534
|
+
.lines()
|
1535
|
+
.map(|s| s.to_string())
|
1536
|
+
.collect::<Vec<_>>();
|
1537
|
+
Ok((headers, Vec::new()))
|
1538
|
+
}
|
1539
|
+
}
|
1540
|
+
|
1541
|
+
/// Given the raw header lines (from an HTTP/1.1 response), parse out the status and headers.
|
1542
|
+
///
|
1543
|
+
/// This is a very simplistic parser; you may wish to use a real HTTP parser.
|
1544
|
+
fn parse_status_and_headers(lines: &[String]) -> (StatusCode, HeaderMap) {
|
1545
|
+
let mut status = StatusCode::OK;
|
1546
|
+
let mut header_map = HeaderMap::new();
|
1547
|
+
|
1548
|
+
if let Some(first_line) = lines.first() {
|
1549
|
+
// Expect a line like "HTTP/1.1 200 OK"
|
1550
|
+
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
1551
|
+
if parts.len() >= 2 {
|
1552
|
+
if let Ok(code) = parts[1].parse::<u16>() {
|
1553
|
+
status = StatusCode::from_u16(code).unwrap_or(StatusCode::OK);
|
1554
|
+
}
|
1555
|
+
}
|
1556
|
+
}
|
1557
|
+
for line in &lines[1..] {
|
1558
|
+
if let Some((name, value)) = line.split_once(':') {
|
1559
|
+
header_map.insert(
|
1560
|
+
HeaderName::from_bytes(name.as_bytes()).unwrap(),
|
1561
|
+
HeaderValue::from_bytes(value.as_bytes()).unwrap(),
|
1562
|
+
);
|
1563
|
+
}
|
1564
|
+
}
|
1565
|
+
(status, header_map)
|
1566
|
+
}
|
1567
|
+
// ─── THE HIJACK HANDLING FUNCTION ───────────────────────────────────────────
|
1568
|
+
|
1569
|
+
async fn read_hijacked_headers(fd: i32) -> (HeaderMap, StatusCode, bool, Vec<u8>) {
|
1570
|
+
let hijack_stream =
|
1571
|
+
tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(fd) }).unwrap();
|
1572
|
+
let mut reader = BufReader::new(hijack_stream);
|
1573
|
+
|
1574
|
+
// Read until blank line => that's the "HTTP/1.1 101\r\nUpg...\r\n\r\n"
|
1575
|
+
let (header_lines, initial_body_bytes) = read_response_headers(&mut reader).await.unwrap();
|
1576
|
+
|
1577
|
+
// Parse them
|
1578
|
+
let (status, hdrs) = parse_status_and_headers(&header_lines);
|
1579
|
+
let requires_upgrade = status == StatusCode::SWITCHING_PROTOCOLS;
|
1580
|
+
|
1581
|
+
// Dont close the FD, we're going to use it later.
|
1582
|
+
let _ = reader.into_inner().into_std().unwrap().into_raw_fd();
|
1583
|
+
|
1584
|
+
(hdrs, status, requires_upgrade, initial_body_bytes)
|
1585
|
+
}
|
1586
|
+
|
1587
|
+
/// This function is called when rack_response.status == -1 (i.e. the app hijacked the stream).
|
1588
|
+
/// It takes the raw UnixStream (from the proxy side) and uses it to parse an HTTP/1.1 response
|
1589
|
+
/// written by the app. It then returns a Hyper Response that uses a streaming body.
|
1590
|
+
/// This function is called when `rack_response.status == -1`.
|
1591
|
+
/// We read the lines Ruby wrote over that hijacked FD, parse them,
|
1592
|
+
/// and build a streaming response. But if Ruby’s lines say `101`,
|
1593
|
+
/// we do not close the connection until Ruby closes its socket.
|
1594
|
+
async fn handle_hijacked_response(
|
1595
|
+
hijack_stream_fd: i32,
|
1596
|
+
parts: hyper::http::request::Parts,
|
1597
|
+
) -> Result<Response<BoxBody<Bytes, Infallible>>, hyper::Error> {
|
1598
|
+
// We read the headers from the raw stream (in HTTP/1.1 format)
|
1599
|
+
// and will either upgrade the connection or use the body as a streaming body.
|
1600
|
+
let (headers, status, requires_upgrade, initial_body_bytes) =
|
1601
|
+
read_hijacked_headers(hijack_stream_fd).await;
|
1602
|
+
|
1603
|
+
let hijack_stream =
|
1604
|
+
tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(hijack_stream_fd) })
|
1605
|
+
.unwrap();
|
1606
|
+
|
1607
|
+
let mut response = if requires_upgrade {
|
1608
|
+
tokio::spawn(async move {
|
1609
|
+
let mut req = Request::from_parts(parts, Empty::<Bytes>::new());
|
1610
|
+
match hyper::upgrade::on(&mut req).await {
|
1611
|
+
Ok(upgraded) => {
|
1612
|
+
two_way_bridge(upgraded, hijack_stream)
|
1613
|
+
.await
|
1614
|
+
.expect("Error in creating two way bridge");
|
1615
|
+
}
|
1616
|
+
Err(e) => eprintln!("upgrade error: {:?}", e),
|
1617
|
+
}
|
1618
|
+
});
|
1619
|
+
Response::new(BoxBody::new(Empty::new()))
|
1620
|
+
} else {
|
1621
|
+
let (tx, rx) = mpsc::channel::<Result<Bytes, io::Error>>(16);
|
1622
|
+
if !initial_body_bytes.is_empty() {
|
1623
|
+
let _ = tx.send(Ok(Bytes::from(initial_body_bytes))).await;
|
1624
|
+
}
|
1625
|
+
|
1626
|
+
tokio::spawn(async move {
|
1627
|
+
let mut reader = BufReader::new(hijack_stream);
|
1628
|
+
let mut buf = [0u8; 1024];
|
1629
|
+
loop {
|
1630
|
+
match reader.read(&mut buf).await {
|
1631
|
+
Ok(0) => {
|
1632
|
+
info!("Hijacked socket closed. EOF");
|
1633
|
+
break;
|
1634
|
+
}
|
1635
|
+
Ok(n) => {
|
1636
|
+
// Forward chunk
|
1637
|
+
if tx
|
1638
|
+
.send(Ok(Bytes::copy_from_slice(&buf[..n])))
|
1639
|
+
.await
|
1640
|
+
.is_err()
|
1641
|
+
{
|
1642
|
+
info!("Client side dropped the channel");
|
1643
|
+
break;
|
1644
|
+
}
|
1645
|
+
}
|
1646
|
+
Err(e) => {
|
1647
|
+
info!("Read error: {:?}", e);
|
1648
|
+
let _ = tx.send(Err(e)).await;
|
1649
|
+
break;
|
1650
|
+
}
|
1651
|
+
}
|
1652
|
+
}
|
1653
|
+
});
|
1654
|
+
Response::new(BoxBody::new(HijackedStreamingBody { rx, done: false }))
|
1655
|
+
};
|
1656
|
+
*response.status_mut() = status;
|
1657
|
+
*response.headers_mut() = headers;
|
1658
|
+
Ok(response)
|
1659
|
+
}
|