osprey 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,1659 @@
1
+ #![deny(unused_crate_dependencies)]
2
+
3
+ use bytes::Bytes;
4
+ use crossbeam::channel::{Receiver, Sender};
5
+ use http_body_util::combinators::BoxBody;
6
+ use http_body_util::{BodyExt, Empty};
7
+ use hyper::body::{Body, Incoming};
8
+ use hyper::header::{HeaderName, HeaderValue};
9
+ use hyper::service::service_fn;
10
+ use hyper::upgrade::Upgraded;
11
+ use hyper::{HeaderMap, Request, Response, StatusCode};
12
+ use hyper_util::rt::{TokioExecutor, TokioIo};
13
+ use hyper_util::server::conn::auto::Builder;
14
+ use libc::{sighandler_t, sleep};
15
+ use log::{debug, error, info, warn};
16
+ use magnus::value::{Id, Lazy, LazyId, ReprValue};
17
+ use magnus::{
18
+ block::Proc,
19
+ exception, function,
20
+ scan_args::{get_kwargs, scan_args, Args, KwArgs},
21
+ Error, Object, RHash, Value,
22
+ };
23
+ use magnus::{eval, Class, Module, RModule, Ruby};
24
+ use magnus::{RArray, RClass};
25
+ use nix::cmsg_space;
26
+ use nix::sys::socket::{
27
+ recvmsg, sendmsg, socketpair, AddressFamily, ControlMessage, ControlMessageOwned, MsgFlags,
28
+ SockFlag, SockType,
29
+ };
30
+ use std::cell::RefCell;
31
+ use std::convert::Infallible;
32
+ use std::os::unix::io::RawFd;
33
+ use std::os::unix::net::UnixStream;
34
+ use std::ptr::null_mut;
35
+ use structs::rack_request::{define_request_info, RackRequest, RackRequestInner};
36
+ use tokio::io::unix::AsyncFd;
37
+ use tokio::io::{AsyncRead, AsyncWrite};
38
+ use tokio::sync::broadcast::channel;
39
+ use tokio_stream::StreamExt;
40
+
41
+ use rb_sys::{rb_thread_call_with_gvl, rb_thread_call_without_gvl, rb_thread_create};
42
+ use simple_logger::SimpleLogger;
43
+ use socket2::{Domain, Protocol, Socket, Type};
44
+ use std::io::{IoSlice, IoSliceMut};
45
+ use std::net::{TcpListener, TcpStream};
46
+ use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd};
47
+ use std::panic;
48
+ use std::pin::Pin;
49
+ use std::sync::atomic::Ordering;
50
+ use std::sync::atomic::{AtomicBool, AtomicUsize};
51
+ use std::sync::{Mutex, RwLock};
52
+ use std::{
53
+ collections::HashMap, ffi::c_void, fs::OpenOptions, net::SocketAddr, os::fd::AsRawFd,
54
+ sync::Arc, time::Duration,
55
+ };
56
+ use structs::rack_response::{define_collector, RackResponse};
57
+ use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream};
58
+ use tokio::runtime::Builder as RuntimeBuilder;
59
+ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
60
+ use tokio::sync::oneshot::{self};
61
+ use tokio::task::{JoinError, JoinHandle};
62
+ use tokio::time::sleep as TokioSleep;
63
+ use tokio_rustls::rustls::ServerConfig;
64
+ use tokio_rustls::TlsAcceptor;
65
+
66
+ mod structs;
67
+
68
+ trait IoStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
69
+ impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin> IoStream for T {}
70
+
71
+ /// --------------------------------------------------
72
+ /// Global/Static shared data for forking & shutdown
73
+ /// --------------------------------------------------
74
+
75
+ static CHILDREN: RwLock<Vec<ChildProcess>> = RwLock::new(vec![]);
76
+
77
+ static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);
78
+ static CTRL_C_SIGNAL_CHANNEL: RwLock<Option<tokio::sync::broadcast::Sender<()>>> =
79
+ RwLock::new(None);
80
+
81
+ pub static OSPREY: Lazy<RModule> = Lazy::new(|ruby| ruby.define_module("Osprey").unwrap());
82
+ pub static ID_SCHEDULE: LazyId = LazyId::new("schedule");
83
+ pub static ID_CALL: LazyId = LazyId::new("call");
84
+ pub static ID_NEW: LazyId = LazyId::new("new");
85
+ pub static ID_PUSH: LazyId = LazyId::new("push");
86
+ pub static ID_RESUME: LazyId = LazyId::new("resume");
87
+ pub static ID_TRANSFORM_VALUES: LazyId = LazyId::new("transform_values");
88
+ pub static ARRAY_PROC: Lazy<Proc> = Lazy::new(|ruby| {
89
+ ruby.module_kernel()
90
+ .funcall::<&str, (&str,), Value>("method", ("Array",))
91
+ .unwrap()
92
+ .funcall("to_proc", ())
93
+ .unwrap()
94
+ });
95
+ pub static WORK_QUEUE: Lazy<Value> = Lazy::new(|ruby| {
96
+ let queue = ruby
97
+ .define_class("Queue", ruby.class_object())
98
+ .unwrap()
99
+ .funcall::<&str, (), Value>("new", ())
100
+ .unwrap();
101
+ ruby.get_inner(&OSPREY)
102
+ .ivar_set("@work_queue", queue)
103
+ .unwrap();
104
+ queue
105
+ });
106
+
107
+ static SCHEDULER_CLASS: Lazy<RClass> = Lazy::new(|ruby| {
108
+ ruby.define_module("Async")
109
+ .unwrap()
110
+ .const_get("Scheduler")
111
+ .unwrap()
112
+ });
113
+
114
+ static FIBER_CLASS: Lazy<RClass> =
115
+ Lazy::new(|ruby| ruby.define_class("Fiber", ruby.class_object()).unwrap());
116
+ /// A simple struct to hold the server config + the Ruby app `Proc`
117
+ #[derive(Clone)]
118
+ struct RackServerConfig {
119
+ host: String,
120
+ port: u16,
121
+ http_port: Option<u16>,
122
+ workers: u16,
123
+ threads: u16,
124
+ cert_path: Option<String>,
125
+ key_path: Option<String>,
126
+ shutdown_timeout: Duration,
127
+ script_name: String,
128
+ before_fork: Option<Proc>,
129
+ after_fork: Option<Proc>,
130
+ }
131
+
132
/// A worker process as tracked by the process that spawned it.
#[derive(Debug, Clone)]
struct ChildProcess {
    /// Worker process id; it receives SIGTERM when shutdown starts.
    pid: i32,
    /// Sending end of the socket over which accepted connection fds are passed to this worker.
    writer: Arc<OwnedFd>,
}
137
+
138
/// Everything an accept loop (parent or worker side) needs; passed through a raw pointer.
struct AcceptLoopArgs {
    /// Plain-HTTP listener, if configured (workers drop it on start).
    http_listener: Option<TcpListener>,
    /// HTTPS listener, if configured (workers drop it on start).
    https_listener: Option<TcpListener>,
    /// Maximum time to wait for in-flight requests once shutdown begins.
    shutdown_timeout: Duration,
    /// Size of the Ruby thread pool and of the parent's Tokio worker pool.
    threads: u16,
    /// Worker side: receiving end of the fd-passing socket from the parent.
    reader: Option<OwnedFd>,
    /// Rack `SCRIPT_NAME` forwarded with each request.
    script_name: String,
    /// TLS certificate path handed to `configure_tls`.
    cert_path: Option<String>,
    /// TLS private key path handed to `configure_tls`.
    key_path: Option<String>,
}
148
+
149
+ unsafe extern "C" fn accept_loop_interrupt(_ptr: *mut c_void) {
150
+ info!("(PID={}) Interrupt accept loop", libc::getpid());
151
+ }
152
+
153
+ unsafe fn setup_listeners(
154
+ server: &RackServerConfig,
155
+ ) -> Result<(Option<TcpListener>, Option<TcpListener>), Error> {
156
+ let mut http_listener = None;
157
+ let mut https_listener = None;
158
+
159
+ // Set up HTTP listener if http_port is provided
160
+ if let Some(http_port) = server.http_port {
161
+ let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
162
+ Error::new(
163
+ exception::io_error(),
164
+ format!("Socket creation failed: {}", e),
165
+ )
166
+ })?;
167
+ socket.set_reuse_address(true).ok();
168
+ socket.set_reuse_port(true).ok();
169
+ socket.set_nonblocking(true).ok();
170
+ socket.set_nodelay(true).ok();
171
+ socket.set_recv_buffer_size(1048576).ok();
172
+
173
+ let address: SocketAddr =
174
+ format!("{}:{}", server.host, http_port)
175
+ .parse()
176
+ .map_err(|e| {
177
+ Error::new(
178
+ exception::io_error(),
179
+ format!(
180
+ "Failed to parse address {}: {}",
181
+ format!("{}:{}", server.host, http_port),
182
+ e
183
+ ),
184
+ )
185
+ })?;
186
+ socket.bind(&address.into()).map_err(|e| {
187
+ Error::new(
188
+ exception::io_error(),
189
+ format!("bind failed for HTTP: {}", e),
190
+ )
191
+ })?;
192
+ socket.listen(1024).map_err(|e| {
193
+ Error::new(
194
+ exception::io_error(),
195
+ format!("listen failed for HTTP: {}", e),
196
+ )
197
+ })?;
198
+ let listener: TcpListener = socket.into();
199
+ listener.set_nonblocking(true).map_err(|e| {
200
+ Error::new(
201
+ exception::io_error(),
202
+ format!("set_nonblocking failed for HTTP: {}", e),
203
+ )
204
+ })?;
205
+ info!("(PID={}) Listening on http://{}", libc::getpid(), address);
206
+ http_listener = Some(listener);
207
+ }
208
+
209
+ // Set up HTTPS listener if cert_path and key_path are provided
210
+ if server.cert_path.is_some() && server.key_path.is_some() {
211
+ let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
212
+ Error::new(
213
+ exception::io_error(),
214
+ format!("Socket creation failed: {}", e),
215
+ )
216
+ })?;
217
+ socket.set_reuse_address(true).ok();
218
+ socket.set_reuse_port(true).ok();
219
+ socket.set_nonblocking(true).ok();
220
+ socket.set_nodelay(true).ok();
221
+ socket.set_recv_buffer_size(1048576).ok();
222
+
223
+ let address: SocketAddr =
224
+ format!("{}:{}", server.host, server.port)
225
+ .parse()
226
+ .map_err(|e| {
227
+ Error::new(
228
+ exception::io_error(),
229
+ format!(
230
+ "Failed to parse address {}: {}",
231
+ format!("{}:{}", server.host, server.port),
232
+ e
233
+ ),
234
+ )
235
+ })?;
236
+ socket.bind(&address.into()).map_err(|e| {
237
+ Error::new(
238
+ exception::io_error(),
239
+ format!("bind failed for HTTPS: {}", e),
240
+ )
241
+ })?;
242
+ socket.listen(1024).map_err(|e| {
243
+ Error::new(
244
+ exception::io_error(),
245
+ format!("listen failed for HTTPS: {}", e),
246
+ )
247
+ })?;
248
+ let listener: TcpListener = socket.into();
249
+ listener.set_nonblocking(true).map_err(|e| {
250
+ Error::new(
251
+ exception::io_error(),
252
+ format!("set_nonblocking failed for HTTP: {}", e),
253
+ )
254
+ })?;
255
+ info!("(PID={}) Listening on https://{}", libc::getpid(), address);
256
+ https_listener = Some(listener);
257
+ }
258
+
259
+ Ok((http_listener, https_listener))
260
+ }
261
+
262
+ async fn await_fd(fd: RawFd) -> Result<(), JoinError> {
263
+ let async_fd_result = AsyncFd::new(fd);
264
+
265
+ if let Ok(async_fd) = async_fd_result {
266
+ return tokio::spawn(async move {
267
+ loop {
268
+ match async_fd.readable().await {
269
+ Ok(guard) => {
270
+ if guard.ready().is_read_closed() {
271
+ break;
272
+ } else {
273
+ let _ = TokioSleep(Duration::from_millis(1000)).await;
274
+ }
275
+ }
276
+ Err(e) => {
277
+ println!("Error waiting for readiness: {:?}", e);
278
+ break;
279
+ }
280
+ }
281
+ }
282
+ })
283
+ .await;
284
+ }
285
+ Ok(())
286
+ }
287
+
288
+ async unsafe fn accept_loop_async(serve_args: &mut AcceptLoopArgs) -> Result<(), Error> {
289
+ let tls_listener = serve_args.https_listener.as_ref().map(|listener| {
290
+ TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
291
+ .expect("Construct Tokio Listener from STD")
292
+ });
293
+ let non_tls_listener = serve_args.http_listener.as_ref().map(|listener| {
294
+ TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
295
+ .expect("Construct Tokio Listener from STD")
296
+ });
297
+
298
+ let (signal_tx, mut signal_rx) = channel(1);
299
+ let (shutdown_tx, mut shutdown_rx) = channel(1);
300
+ *CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);
301
+ let children = CHILDREN.read().unwrap().to_vec();
302
+ let mut child_cycle = children.iter().cycle();
303
+
304
+ tokio::spawn(async move {
305
+ signal_rx.recv().await.ok();
306
+ info!(
307
+ "(PID={}) Signal received. Initiating shutdown",
308
+ libc::getpid()
309
+ );
310
+ shutdown_tx.send(()).ok();
311
+ });
312
+
313
+ loop {
314
+ tokio::select! {
315
+ conn = async { if let Some(listener) = tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if tls_listener.is_some() => {
316
+ if let Some(Ok((stream, _peer_addr))) = conn {
317
+ let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
318
+
319
+ let iov = [IoSlice::new(b"1")];
320
+ let next_child = child_cycle.next().expect("Next child process");
321
+ sendmsg::<()>(
322
+ next_child.writer.as_raw_fd() as i32,
323
+ &iov,
324
+ &[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
325
+ MsgFlags::empty(),
326
+ None, // No address
327
+ ).expect("Failed to send file descriptor");
328
+
329
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
330
+ break;
331
+ }
332
+ }
333
+ },
334
+ conn = async { if let Some(listener) = non_tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if non_tls_listener.is_some() => {
335
+ if let Some(Ok((stream, _peer_addr))) = conn {
336
+ let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
337
+ let iov = [IoSlice::new(b"0")];
338
+ let next_child = child_cycle.next().expect("Next child process");
339
+ sendmsg::<()>(
340
+ next_child.writer.as_raw_fd() as i32,
341
+ &iov,
342
+ &[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
343
+ MsgFlags::empty(),
344
+ None, // No address
345
+ ).expect("Failed to send file descriptor");
346
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
347
+ break;
348
+ }
349
+ }
350
+ }
351
+ _ = shutdown_rx.recv() => {
352
+ info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
353
+ break;
354
+ }
355
+ }
356
+ }
357
+
358
+ info!(
359
+ "(PID={}) Waiting up to {:?} for in-flight requests to complete",
360
+ libc::getpid(),
361
+ serve_args.shutdown_timeout
362
+ );
363
+
364
+ debug!("(PID={}) Dropping listeners", libc::getpid());
365
+ drop(tls_listener);
366
+ drop(non_tls_listener);
367
+ initiate_shutdown(0, 0);
368
+
369
+ Ok(())
370
+ }
371
+
372
+ use std::io;
373
+ use tokio::{io::AsyncWriteExt, net::UnixStream as TokioUnixStream};
374
+
375
+ async fn bridge_with_logging(
376
+ mut from: (impl AsyncRead + Unpin),
377
+ mut to: (impl AsyncWrite + Unpin),
378
+ direction_label: &str,
379
+ ) -> io::Result<u64> {
380
+ use tokio::io::{AsyncReadExt, AsyncWriteExt};
381
+ let mut buf = [0u8; 4096];
382
+ let mut total_bytes = 0u64;
383
+
384
+ loop {
385
+ let n = from.read(&mut buf).await?;
386
+ if n == 0 {
387
+ // EOF => no more data
388
+ break;
389
+ }
390
+ println!(
391
+ "[bridge {}] Read {} bytes: {:?}",
392
+ direction_label,
393
+ n,
394
+ &buf[..n]
395
+ );
396
+
397
+ to.write_all(&buf[..n]).await?;
398
+ total_bytes += n as u64;
399
+ }
400
+ Ok(total_bytes)
401
+ }
402
+
403
+ async fn two_way_bridge(upgraded: Upgraded, local: TokioUnixStream) -> io::Result<()> {
404
+ // Wrap the `Upgraded` with `TokioIo` so it implements AsyncRead+AsyncWrite
405
+ let client_io = TokioIo::new(upgraded);
406
+
407
+ // Split each side
408
+ let (mut lr, mut lw) = tokio::io::split(local);
409
+ let (mut cr, mut cw) = tokio::io::split(client_io);
410
+
411
+ let to_ruby = tokio::spawn(async move {
412
+ if let Err(e) = tokio::io::copy(&mut cr, &mut lw).await {
413
+ eprintln!("Error copying upgraded->local: {:?}", e);
414
+ }
415
+ });
416
+ let from_ruby = tokio::spawn(async move {
417
+ if let Err(e) = tokio::io::copy(&mut lr, &mut cw).await {
418
+ eprintln!("Error copying upgraded->local: {:?}", e);
419
+ }
420
+ });
421
+
422
+ let _ = to_ruby.await;
423
+ let _ = from_ruby.await;
424
+ Ok(())
425
+ }
426
+
427
+ async unsafe fn handle_hyper_request(
428
+ req: Request<Incoming>,
429
+ peer_addr: SocketAddr,
430
+ sender: Sender<RackRequest>,
431
+ script_name: String,
432
+ ) -> Result<Response<BoxBody<Bytes, Infallible>>, &'static str> {
433
+ let (parts, body) = req.into_parts();
434
+
435
+ let has_body = !body.is_end_stream();
436
+
437
+ // 1. (Optional) If there's a body, we'll expose it as an IO into our Ruby process.
438
+ let (local_fd, remote_fd) = if has_body {
439
+ let (local, remote) = TokioUnixStream::pair().map_err(|_| "socketpair failed")?;
440
+ let local_fd = local.as_raw_fd();
441
+ let (_lr, mut lw) = tokio::io::split(local);
442
+ let mut data_stream = body.into_data_stream();
443
+
444
+ // Stream all incoming data into our remote FD (Opened inside Ruby as an IO)
445
+ tokio::spawn(async move {
446
+ while let Some(Ok(chunk)) = data_stream.next().await {
447
+ info!("Got write chunk ! {:?}", chunk);
448
+ if lw.write_all(&chunk).await.is_err() {
449
+ break;
450
+ }
451
+ }
452
+ info!("Relinquishing ownership of local_fd");
453
+ _lr.unsplit(lw).into_std().unwrap().into_raw_fd()
454
+ });
455
+
456
+ let remote_fd = remote.into_std().unwrap().into_raw_fd();
457
+ (Some(local_fd), remote_fd)
458
+ } else {
459
+ (None, -1)
460
+ };
461
+
462
+ // 2) Build and send a RackRequest to our Rack app.
463
+ let (rack_req, resp_rx) =
464
+ map_hyper_request(parts.clone(), peer_addr, script_name, remote_fd).await;
465
+ sender.send(rack_req).expect("Failed to send to Ruby");
466
+
467
+ // 3) Then wait for our Rack app to return a RackResponse
468
+ let rack_response = resp_rx.await.map_err(|_| "Ruby channel closed")?;
469
+
470
+ // 4) We return a negative status if the response was hijacked. We'll reuse the body pipe
471
+ // if it already exists, otherwise, we can count on the hijacked response to have created a new one.
472
+ if rack_response.status < 0 {
473
+ let local_fd = match local_fd {
474
+ Some(ld) => ld,
475
+ None => (rack_response.status * -1) as i32,
476
+ };
477
+
478
+ return match handle_hijacked_response(local_fd, parts).await {
479
+ Ok(resp) => Ok(resp),
480
+ Err(_) => Err("Error handling hijacked response"),
481
+ };
482
+ }
483
+
484
+ // 5) Otherwise, normal HTTP response
485
+ Ok(rack_response.into_hyper_response())
486
+ }
487
+
488
+ async fn handle_connection(
489
+ input_stream: Pin<Box<tokio::net::TcpStream>>,
490
+ tls_acceptor: Option<TlsAcceptor>,
491
+ server: Arc<Builder<TokioExecutor>>,
492
+ peer_addr: SocketAddr,
493
+ sender: Sender<RackRequest>,
494
+ script_name: String,
495
+ ) {
496
+ unsafe {
497
+ debug!("(PID={}) Handling connection", libc::getpid());
498
+ }
499
+ let stream: TokioIo<_> = match &tls_acceptor {
500
+ Some(acceptor) => match acceptor.accept(input_stream).await {
501
+ Ok(tls_stream) => TokioIo::new(Box::pin(tls_stream) as Pin<Box<dyn IoStream>>),
502
+ Err(err) => {
503
+ warn!("TLS handshake error: {}", err);
504
+ return;
505
+ }
506
+ },
507
+ None => TokioIo::new(Box::pin(input_stream) as Pin<Box<dyn IoStream>>),
508
+ };
509
+
510
+ unsafe {
511
+ debug!("(PID={}) serve connection with upgrades", libc::getpid());
512
+ }
513
+
514
+ if let Err(err) = server
515
+ .serve_connection_with_upgrades(
516
+ stream,
517
+ service_fn(move |hyper_request| unsafe {
518
+ handle_hyper_request(
519
+ hyper_request,
520
+ peer_addr,
521
+ sender.clone(),
522
+ script_name.clone(),
523
+ )
524
+ }),
525
+ )
526
+ .await
527
+ {
528
+ debug!("Connection aborted: {:?}", err);
529
+ }
530
+
531
+ unsafe { debug!("(PID={}) Connection closed", libc::getpid()) };
532
+ }
533
+
534
+ async unsafe fn accept_ruby_connection_async(
535
+ serve_args: &mut AcceptLoopArgs,
536
+ sender: Sender<RackRequest>,
537
+ ) {
538
+ let reader = serve_args.reader.as_mut().unwrap();
539
+
540
+ drop(serve_args.http_listener.take());
541
+ drop(serve_args.https_listener.take());
542
+
543
+ let buf_raw = &mut [0u8; 1024];
544
+ let mut buf = [IoSliceMut::new(buf_raw)];
545
+ let server = Arc::new(Builder::new(TokioExecutor::new()));
546
+ let mut cmsg_space = cmsg_space!(Vec<RawFd>);
547
+
548
+ let (signal_tx, mut signal_rx) = channel(1);
549
+ let (shutdown_tx, mut shutdown_rx) = channel(1);
550
+ *CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);
551
+
552
+ let mut current_requests: (
553
+ UnboundedSender<JoinHandle<()>>,
554
+ UnboundedReceiver<JoinHandle<()>>,
555
+ ) = tokio::sync::mpsc::unbounded_channel();
556
+
557
+ let in_flight_requests_handle = tokio::spawn(async move {
558
+ loop {
559
+ tokio::select! {
560
+ _ = signal_rx.recv() => {
561
+ info!("(PID={}) Signal received. Initiating shutdown", libc::getpid());
562
+ shutdown_tx.send(()).ok();
563
+ break
564
+ }
565
+ result = current_requests.1.recv() => {
566
+ match result {
567
+ Some(handle) => {
568
+ if let Err(e) = handle.await {
569
+ error!("Error handling request: {:?}", e);
570
+ }
571
+ },
572
+ _ => {
573
+ continue;
574
+ }
575
+ }
576
+ }
577
+ }
578
+ }
579
+ info!("(PID={}) Flushing final requests", libc::getpid());
580
+
581
+ info!(
582
+ "(PID={}) Awaiting in-flight request for clean shut down",
583
+ libc::getpid()
584
+ );
585
+ while let Ok(handle) = current_requests.1.try_recv() {
586
+ if let Err(e) = handle.await {
587
+ error!("Error handling request: {:?}", e);
588
+ }
589
+ }
590
+ });
591
+ let child_socket = Arc::new(AsyncFd::new(reader.as_raw_fd()).unwrap());
592
+ let tls_config = configure_tls(serve_args.cert_path.clone(), serve_args.key_path.clone());
593
+ let tls_acceptor = tls_config.map(TlsAcceptor::from);
594
+
595
+ info!("(PID={}) Starting main loop", libc::getpid());
596
+
597
+ loop {
598
+ tokio::select! {
599
+ guard = child_socket.readable() => {
600
+ match guard {
601
+ Ok(mut ready_guard) => {
602
+ while let Ok(msg) = recvmsg::<()>(
603
+ reader.as_raw_fd(),
604
+ &mut buf,
605
+ Some(&mut cmsg_space),
606
+ MsgFlags::MSG_DONTWAIT,
607
+ ) {
608
+ if msg.bytes == 0 {
609
+ info!("(PID={}) Received EOF. Exiting", libc::getpid());
610
+ return;
611
+ }
612
+
613
+ if let Some(ControlMessageOwned::ScmRights(fds)) = msg.cmsgs().unwrap().next() {
614
+ let fd: i32 = *fds.first().expect("File descriptor");
615
+ let tls: bool = buf[0][0] == '1' as u8;
616
+
617
+ let stream = TokioTcpStream::from_std(unsafe { TcpStream::from_raw_fd(fd) })
618
+ .expect("Tokio from std");
619
+
620
+
621
+ match stream.peer_addr(){
622
+ Ok(peer_addr) => {
623
+ let server_clone = server.clone();
624
+
625
+ let script_name_clone = serve_args.script_name.clone();
626
+ let tls_acceptor = if tls { tls_acceptor.clone() } else { None };
627
+ let sender_clone = sender.clone();
628
+ let handle = tokio::spawn(async move {
629
+ handle_connection(
630
+ Box::pin(stream),
631
+ tls_acceptor,
632
+ server_clone,
633
+ peer_addr,
634
+ sender_clone,
635
+ script_name_clone
636
+ )
637
+ .await;
638
+ });
639
+
640
+ info!("(PID={}) Spawned connection handler", libc::getpid());
641
+ if let Err(e) = current_requests.0.send(handle) {
642
+ error!("Failed to enqueue request: {:?}", e);
643
+ }
644
+ }
645
+ Err(e) => {
646
+ error!("Failed to get peer address: {:?}", e);
647
+ continue;
648
+ }
649
+ }
650
+
651
+ } else {
652
+ error!("No file descriptor received");
653
+ }
654
+ }
655
+
656
+ // Clear readiness state after processing
657
+ ready_guard.clear_ready();
658
+ }
659
+ Err(e) => {
660
+ error!("(PID={}) Failed to poll readability: {:?}", libc::getpid(), e);
661
+ }
662
+ }
663
+ }
664
+ _ = shutdown_rx.recv() => {
665
+ info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
666
+ break;
667
+ }
668
+ }
669
+ }
670
+
671
+ tokio::select! {
672
+ _ = in_flight_requests_handle => {
673
+ debug!("(PID={}) In-flight requests handled gracefully", libc::getpid());
674
+ }
675
+ _ = tokio::time::sleep(serve_args.shutdown_timeout) => {
676
+ debug!("(PID={}) Forced shutdown timeout reached", libc::getpid());
677
+ }
678
+ }
679
+
680
+ debug!("(PID={}) Dropping listeners", libc::getpid());
681
+ initiate_shutdown(0, 0);
682
+ }
683
+
684
+ pub unsafe fn call_without_gvl<F, R>(f: F) -> R
685
+ where
686
+ F: FnOnce() -> R,
687
+ {
688
+ // This is the function that Ruby calls “in the background” with the GVL released.
689
+ extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
690
+ where
691
+ F: FnOnce() -> R,
692
+ {
693
+ // 1) Reconstruct the Box that holds our closure
694
+ let closure_ptr = arg as *mut Option<F>;
695
+ let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
696
+
697
+ // 2) Call the user’s closure
698
+ let result = closure();
699
+
700
+ // 3) Box up the result so we can return a pointer to it
701
+ let boxed_result = Box::new(result);
702
+ Box::into_raw(boxed_result) as *mut c_void
703
+ }
704
+
705
+ // Box up the closure so we have a stable pointer
706
+ let mut closure_opt = Some(f);
707
+ let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
708
+
709
+ // 4) Actually call `rb_thread_call_without_gvl`
710
+ let raw_result_ptr =
711
+ rb_thread_call_without_gvl(Some(trampoline::<F, R>), closure_ptr, None, null_mut());
712
+
713
+ // 5) Convert the returned pointer back into R
714
+ let result_box = Box::from_raw(raw_result_ptr as *mut R);
715
+ *result_box
716
+ }
717
+
718
+ pub unsafe fn call_with_gvl<F, R>(f: F) -> R
719
+ where
720
+ F: FnOnce() -> R,
721
+ {
722
+ // This is the function that Ruby calls “in the background” with the GVL released.
723
+ extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
724
+ where
725
+ F: FnOnce() -> R,
726
+ {
727
+ // 1) Reconstruct the Box that holds our closure
728
+ let closure_ptr = arg as *mut Option<F>;
729
+ let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
730
+
731
+ // 2) Call the user’s closure
732
+ let result = closure();
733
+
734
+ // 3) Box up the result so we can return a pointer to it
735
+ let boxed_result = Box::new(result);
736
+ Box::into_raw(boxed_result) as *mut c_void
737
+ }
738
+
739
+ // Box up the closure so we have a stable pointer
740
+ let mut closure_opt = Some(f);
741
+ let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
742
+
743
+ // 4) Actually call `rb_thread_call_without_gvl`
744
+ let raw_result_ptr = rb_thread_call_with_gvl(Some(trampoline::<F, R>), closure_ptr);
745
+
746
+ // 5) Convert the returned pointer back into R
747
+ let result_box = Box::from_raw(raw_result_ptr as *mut R);
748
+ *result_box
749
+ }
750
+
751
+ unsafe extern "C" fn accept_ruby_connection_loop(ptr: *mut c_void) -> *mut c_void {
752
+ let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
753
+ let mut builder: RuntimeBuilder = RuntimeBuilder::new_current_thread();
754
+
755
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
756
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
757
+
758
+ let runtime = builder
759
+ .thread_name("osprey-accept-runtime")
760
+ .thread_stack_size(3 * 1024 * 1024)
761
+ .enable_io()
762
+ .enable_time()
763
+ .build()
764
+ .expect("Failed to build Tokio runtime");
765
+
766
+ let (sender, receiver) = crossbeam::channel::unbounded::<RackRequest>();
767
+ start_ruby_thread_pool(serve_args.threads as usize, Arc::new(receiver));
768
+
769
+ let _ = runtime.block_on(accept_ruby_connection_async(serve_args, sender));
770
+
771
+ std::ptr::null_mut()
772
+ }
773
+
774
+ unsafe extern "C" fn accept_loop(ptr: *mut c_void) -> *mut c_void {
775
+ let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
776
+ let mut builder: RuntimeBuilder = if serve_args.threads == 1 {
777
+ RuntimeBuilder::new_current_thread()
778
+ } else {
779
+ RuntimeBuilder::new_multi_thread()
780
+ };
781
+ let runtime = builder
782
+ .worker_threads(serve_args.threads as usize)
783
+ .thread_name("osprey-accept-runtime")
784
+ .thread_stack_size(3 * 1024 * 1024)
785
+ .enable_io()
786
+ .enable_time()
787
+ .build()
788
+ .expect("Failed to build Tokio runtime");
789
+
790
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
791
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
792
+
793
+ let _ = runtime.block_on(accept_loop_async(serve_args));
794
+ runtime.shutdown_timeout(Duration::from_millis(2000));
795
+
796
+ info!(
797
+ "(PID={}) accept_loop finished, shutting down.",
798
+ libc::getpid()
799
+ );
800
+
801
+ libc::signal(libc::SIGTERM, libc::SIG_DFL);
802
+ libc::signal(libc::SIGINT, libc::SIG_DFL);
803
+
804
+ debug!(
805
+ "(PID={}) HTTP listener state: {:?}",
806
+ libc::getpid(),
807
+ serve_args.http_listener
808
+ );
809
+ debug!("(PID={}) Dropping STD listeners", libc::getpid());
810
+ drop(serve_args.http_listener.take());
811
+ drop(serve_args.https_listener.take());
812
+
813
+ std::ptr::null_mut()
814
+ }
815
+
816
+ unsafe extern "C" fn gather_requests_no_gvl(ptr: *mut c_void) -> *mut c_void {
817
+ let receiver = unsafe { &*(ptr as *const Arc<Receiver<RackRequest>>) };
818
+
819
+ let batch_size = 25;
820
+ let mut batch = Vec::with_capacity(batch_size);
821
+
822
+ // Try to collect a batch of up to `batch_size` requests
823
+ while batch.len() < batch_size {
824
+ match receiver.try_recv() {
825
+ Ok(request) => batch.push(request),
826
+ Err(_) => break,
827
+ }
828
+ }
829
+
830
+ // If the batch is empty, wait for a single request
831
+ if batch.is_empty() {
832
+ match receiver.recv_timeout(Duration::from_micros(100)) {
833
+ Ok(request) => batch.push(request),
834
+ Err(_) => (),
835
+ }
836
+ }
837
+
838
+ // Process the collected batch of requests
839
+ Box::into_raw(Box::new(batch)) as *mut c_void
840
+ }
841
+
842
+ unsafe fn gather_requests(receiver: &Arc<Receiver<RackRequest>>) -> Vec<RackRequest> {
843
+ let data_ptr = receiver as *const Arc<Receiver<RackRequest>> as *mut c_void;
844
+ let result_ptr =
845
+ rb_thread_call_without_gvl(Some(gather_requests_no_gvl), data_ptr, None, null_mut());
846
+ let result = Box::from_raw(result_ptr as *mut Vec<RackRequest>);
847
+ *result
848
+ }
849
+
850
+ unsafe extern "C" fn ruby_thread_worker(ptr: *mut c_void) -> u64 {
851
+ let ruby = Ruby::get_unchecked();
852
+ let receiver = *Box::from_raw(ptr as *mut Arc<Receiver<RackRequest>>);
853
+ let use_scheduler = true;
854
+
855
+ if use_scheduler {
856
+ let receiver_clone = receiver.clone();
857
+ let fiber_class = ruby.get_inner(&FIBER_CLASS);
858
+ let scheduler = ruby.get_inner(&SCHEDULER_CLASS).new_instance(()).unwrap();
859
+ fiber_class
860
+ .funcall::<_, _, Value>("set_scheduler", (scheduler,))
861
+ .unwrap();
862
+ let fib = ruby.proc_from_fn(move |_rb, _arg, _blk| {
863
+ let ruby = Ruby::get_unchecked();
864
+ let fiber_class = ruby.get_inner(&FIBER_CLASS);
865
+ let receiver = receiver_clone.clone();
866
+ let scheduler: Value = fiber_class.funcall("scheduler", ()).unwrap();
867
+ let id_yield = ruby.intern("yield");
868
+ let atomic_counter = Arc::new(AtomicUsize::new(0));
869
+ loop {
870
+ let current_count = atomic_counter.load(Ordering::Relaxed);
871
+ if current_count > 0 {
872
+ if let Err(_) = scheduler.funcall::<Id, (), Value>(id_yield, ()) {
873
+ break;
874
+ }
875
+ }
876
+ let batch = gather_requests(&receiver);
877
+ for request in batch {
878
+ let atomic_counter_clone = Arc::clone(&atomic_counter);
879
+ let fiber = ruby
880
+ .fiber_new_from_fn(Default::default(), move |ruby, _args, _block| {
881
+ atomic_counter_clone.fetch_add(1, Ordering::Relaxed);
882
+ request.process(ruby).expect("Error processing request");
883
+ atomic_counter_clone.fetch_sub(1, Ordering::Relaxed);
884
+ })
885
+ .unwrap();
886
+
887
+ scheduler.funcall::<_, _, Value>(*ID_RESUME, (fiber,)).ok();
888
+ }
889
+
890
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
891
+ break;
892
+ }
893
+ }
894
+ });
895
+ fiber_class
896
+ .funcall_with_block::<_, _, Value>(*ID_SCHEDULE, (), fib)
897
+ .expect("scheduled fiber");
898
+ } else {
899
+ loop {
900
+ let batch = gather_requests(&receiver);
901
+ for request in batch {
902
+ request.process(&ruby).expect("Error processing request");
903
+ }
904
+
905
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
906
+ break;
907
+ }
908
+ }
909
+ }
910
+ 0
911
+ }
912
+
913
+ fn start_ruby_thread_pool(pool_size: usize, receiver: Arc<Receiver<RackRequest>>) {
914
+ for _i in 0..pool_size {
915
+ unsafe {
916
+ rb_thread_create(
917
+ Some(ruby_thread_worker),
918
+ Box::into_raw(Box::new(receiver.clone())) as *mut c_void,
919
+ )
920
+ };
921
+ }
922
+ }
923
+
924
+ unsafe fn initiate_shutdown(signum: i32, action: sighandler_t) {
925
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
926
+ return;
927
+ }
928
+ SHUTDOWN_REQUESTED.store(true, Ordering::SeqCst);
929
+ info!(
930
+ "(PID={}) Initiating shutdown with signal {}: {}",
931
+ libc::getpid(),
932
+ signum,
933
+ action
934
+ );
935
+
936
+ if let Some(sender) = CTRL_C_SIGNAL_CHANNEL.write().unwrap().take() {
937
+ let _ = sender.send(());
938
+ }
939
+
940
+ let children = CHILDREN.read().unwrap();
941
+ for child in children.iter() {
942
+ if child.pid > 0 {
943
+ let sig_result = libc::kill(child.pid, libc::SIGTERM);
944
+ if sig_result == -1 {
945
+ eprint!("Failed to send SIGTERM to child {}", child.pid);
946
+ }
947
+ let mut status = 0;
948
+ let wait_result = libc::waitpid(child.pid, &mut status, libc::WNOHANG);
949
+ if wait_result != 0 {
950
+ libc::kill(child.pid, libc::SIGKILL);
951
+ libc::waitpid(child.pid, &mut status, 0);
952
+ }
953
+ }
954
+ }
955
+ }
956
+
957
+ struct PreforkArgs {
958
+ workers: u16,
959
+ before_fork: Option<Proc>,
960
+ after_fork: Option<Proc>,
961
+ }
962
+
963
+ pub unsafe extern "C" fn kill_other_threads(
964
+ _: *mut std::os::raw::c_void,
965
+ ) -> *mut std::os::raw::c_void {
966
+ let thread_class =
967
+ RClass::from_value(eval("Thread").expect("Failed to fetch Ruby Thread class"))
968
+ .expect("Failed to fetch Ruby Thread class");
969
+ let threads: RArray = thread_class
970
+ .funcall("list", ())
971
+ .expect("Failed to list Ruby threads");
972
+ let current: Value = thread_class
973
+ .funcall("current", ())
974
+ .expect("Failed to get current thread");
975
+
976
+ // Signal thread exit
977
+ for thr in threads {
978
+ let same: bool = thr
979
+ .funcall("==", (current,))
980
+ .expect("Failed to compare threads");
981
+ if !same {
982
+ let _: Option<Value> = thr.funcall("exit", ()).expect("Failed to exit thread");
983
+ }
984
+ }
985
+
986
+ // Attempt to join
987
+ for thr in threads {
988
+ let same: bool = thr
989
+ .funcall("==", (current,))
990
+ .expect("Failed to compare threads");
991
+ if !same {
992
+ let _: Option<Value> = thr
993
+ .funcall("join", (0.5_f64,))
994
+ .expect("Failed to join thread");
995
+ }
996
+ }
997
+
998
+ // Terminate
999
+ for thr in threads {
1000
+ let same: bool = thr
1001
+ .funcall("==", (current,))
1002
+ .expect("Failed to compare threads");
1003
+ if !same {
1004
+ let alive: bool = thr
1005
+ .funcall("alive?", ())
1006
+ .expect("Failed to check if thread is alive");
1007
+ if alive {
1008
+ let _: Option<Value> = thr.funcall("kill", ()).expect("Failed to kill thread");
1009
+ }
1010
+ }
1011
+ }
1012
+
1013
+ std::ptr::null_mut()
1014
+ }
1015
+
1016
+ pub unsafe extern "C" fn before_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1017
+ let pf_args = &mut *(ptr as *mut PreforkArgs);
1018
+ debug!("(PID={}) Before fork", libc::getpid());
1019
+ if let Some(ref p) = pf_args.before_fork {
1020
+ let _ = p.call::<(), Value>(());
1021
+ };
1022
+ std::ptr::null_mut()
1023
+ }
1024
+
1025
+ pub unsafe extern "C" fn after_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1026
+ let pf_args = &mut *(ptr as *mut PreforkArgs);
1027
+ debug!("(PID={}) After fork", libc::getpid());
1028
+ if let Some(ref p) = pf_args.after_fork {
1029
+ let _ = p.call::<(), Value>(());
1030
+ };
1031
+ std::ptr::null_mut()
1032
+ }
1033
+
1034
+ pub unsafe extern "C" fn rb_fork(_: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1035
+ let ruby = Ruby::get_unchecked();
1036
+ let pid: Option<i32> = ruby
1037
+ .module_kernel()
1038
+ .funcall(ruby.intern("fork"), ())
1039
+ .unwrap();
1040
+ if let Some(child_pid) = pid {
1041
+ debug!(
1042
+ "(PID={}) Forked child process with PID {}",
1043
+ libc::getpid(),
1044
+ child_pid
1045
+ );
1046
+ return child_pid as *mut std::os::raw::c_void;
1047
+ } else {
1048
+ debug!("(PID={}) New child process starting ", libc::getpid());
1049
+ std::ptr::null_mut()
1050
+ }
1051
+ }
1052
+
1053
+ unsafe extern "C" fn prefork_setup(ptr: *mut c_void) -> *mut c_void {
1054
+ let pf_args = &*(ptr as *mut PreforkArgs);
1055
+ let workers = pf_args.workers;
1056
+ if workers >= 1 {
1057
+ let mut children = Vec::with_capacity(workers as usize - 1);
1058
+ rb_thread_call_with_gvl(Some(before_fork), ptr);
1059
+ rb_thread_call_with_gvl(Some(kill_other_threads), std::ptr::null_mut());
1060
+
1061
+ for _i in 0..workers {
1062
+ let (read, write) = socketpair(
1063
+ AddressFamily::Unix,
1064
+ SockType::Stream,
1065
+ None,
1066
+ SockFlag::empty(),
1067
+ )
1068
+ .unwrap();
1069
+
1070
+ let response_ptr = rb_thread_call_with_gvl(Some(rb_fork), std::ptr::null_mut());
1071
+ let pid = response_ptr as i32;
1072
+ debug!("Got reader and writer");
1073
+
1074
+ match pid {
1075
+ 0 => {
1076
+ let _ = Box::into_raw(Box::new(write));
1077
+
1078
+ rb_thread_call_with_gvl(Some(after_fork), ptr);
1079
+ let devnull = OpenOptions::new()
1080
+ .write(true)
1081
+ .open("/dev/null")
1082
+ .expect("Failed to open /dev/null");
1083
+ unsafe {
1084
+ libc::dup2(devnull.as_raw_fd(), libc::STDIN_FILENO);
1085
+ }
1086
+ return Box::into_raw(Box::new(read)) as *mut c_void;
1087
+ }
1088
+ child_pid => {
1089
+ let _ = Box::into_raw(Box::new(read));
1090
+ children.push(ChildProcess {
1091
+ pid: child_pid,
1092
+ writer: Arc::new(write),
1093
+ })
1094
+ }
1095
+ }
1096
+ }
1097
+ CHILDREN.write().unwrap().extend(children);
1098
+ }
1099
+ return std::ptr::null_mut();
1100
+ }
1101
+
1102
+ // TLS configuration
1103
+ fn configure_tls(cert_path: Option<String>, key_path: Option<String>) -> Option<Arc<ServerConfig>> {
1104
+ if cert_path.is_none() || key_path.is_none() {
1105
+ return None;
1106
+ }
1107
+ use tokio_rustls::rustls::{Certificate, PrivateKey};
1108
+
1109
+ let cert_bytes =
1110
+ std::fs::read(cert_path.expect("Expected a cert_path")).expect("Failed to read cert file");
1111
+ let key_bytes =
1112
+ std::fs::read(key_path.expect("Expected a key_path")).expect("Failed to read key file");
1113
+
1114
+ let certs = vec![Certificate(cert_bytes)];
1115
+ let key = PrivateKey(key_bytes);
1116
+
1117
+ let mut config = ServerConfig::builder()
1118
+ .with_safe_defaults()
1119
+ .with_no_client_auth()
1120
+ .with_single_cert(certs, key)
1121
+ .expect("Failed to build TLS config");
1122
+ config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1123
+
1124
+ Some(Arc::new(config))
1125
+ }
1126
+
1127
+ fn filter_keywords(keywords: RHash, allowed: &[&str]) -> RHash {
1128
+ let filtered = RHash::new();
1129
+ for key in keywords.funcall::<_, _, RArray>("keys", ()).unwrap() {
1130
+ let key_str: String = key.funcall("to_s", ()).unwrap();
1131
+ if allowed.contains(&key_str.as_str()) {
1132
+ filtered.aset(key, keywords.get(key).unwrap()).unwrap();
1133
+ }
1134
+ }
1135
+ filtered
1136
+ }
1137
+ /// The main "start(...)" function that Ruby calls via `Osprey.start`
1138
+ fn start_server(args: &[Value]) -> Result<(), Error> {
1139
+ SHUTDOWN_REQUESTED.store(false, Ordering::SeqCst);
1140
+
1141
+ let srv = RackServerConfig::from_args(args)?;
1142
+
1143
+ let mut http_listener: Option<TcpListener> = None;
1144
+ let mut https_listener: Option<TcpListener> = None;
1145
+ let mut reader: Option<OwnedFd> = None;
1146
+
1147
+ let mut pf_args = PreforkArgs {
1148
+ workers: srv.workers,
1149
+ before_fork: srv.before_fork,
1150
+ after_fork: srv.after_fork,
1151
+ };
1152
+
1153
+ let pf_args_ptr = &mut pf_args as *mut PreforkArgs as *mut c_void;
1154
+
1155
+ unsafe {
1156
+ info!("(PID={}) Parent process started.", libc::getpid());
1157
+ let ptr = rb_thread_call_without_gvl(
1158
+ Some(prefork_setup),
1159
+ pf_args_ptr,
1160
+ None,
1161
+ std::ptr::null_mut(),
1162
+ );
1163
+ if !ptr.is_null() {
1164
+ reader = Some(*Box::from_raw(ptr as *mut OwnedFd));
1165
+ }
1166
+ }
1167
+
1168
+ unsafe {
1169
+ let is_parent = !CHILDREN.read().unwrap().is_empty(); //|| srv.workers == 1;
1170
+
1171
+ if is_parent {
1172
+ (http_listener, https_listener) = setup_listeners(&srv)?;
1173
+ }
1174
+ let serve_args = AcceptLoopArgs {
1175
+ http_listener,
1176
+ https_listener,
1177
+ reader,
1178
+ cert_path: srv.cert_path,
1179
+ key_path: srv.key_path,
1180
+ threads: srv.threads,
1181
+ shutdown_timeout: srv.shutdown_timeout,
1182
+ script_name: srv.script_name,
1183
+ };
1184
+ let args_ptr = Box::into_raw(Box::new(serve_args)) as *mut c_void;
1185
+ if is_parent {
1186
+ call_without_gvl(|| 3 + 2);
1187
+ rb_thread_call_without_gvl(
1188
+ Some(accept_loop),
1189
+ args_ptr,
1190
+ Some(accept_loop_interrupt),
1191
+ std::ptr::null_mut(),
1192
+ );
1193
+ wait_for_children(srv.shutdown_timeout);
1194
+ CHILDREN.write().unwrap().clear();
1195
+ } else {
1196
+ rb_thread_call_without_gvl(
1197
+ Some(accept_ruby_connection_loop),
1198
+ args_ptr,
1199
+ None,
1200
+ std::ptr::null_mut(),
1201
+ );
1202
+ }
1203
+ }
1204
+
1205
+ debug!("(PID={}) accept_loop returned. Exiting.", unsafe {
1206
+ libc::getpid()
1207
+ });
1208
+ Ok(())
1209
+ }
1210
+
1211
+ unsafe fn is_pid_running(pid: i32) -> bool {
1212
+ let mut status = 0;
1213
+ let result = unsafe { libc::waitpid(pid, &mut status, libc::WNOHANG) };
1214
+ return result == 0;
1215
+ }
1216
+
1217
+ fn wait_for_children(shutdown_timeout: Duration) {
1218
+ unsafe {
1219
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
1220
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
1221
+ }
1222
+
1223
+ for child_process in CHILDREN.read().unwrap().iter() {
1224
+ let mut status = 0;
1225
+
1226
+ loop {
1227
+ if !unsafe { is_pid_running(child_process.pid) } {
1228
+ break;
1229
+ }
1230
+
1231
+ let wait_result =
1232
+ unsafe { libc::waitpid(child_process.pid, &mut status, libc::WNOHANG) };
1233
+
1234
+ if wait_result == -1 {
1235
+ // Error occurred, handle appropriately
1236
+
1237
+ debug!(
1238
+ "(PID={}) Error waiting for child process: {}",
1239
+ child_process.pid,
1240
+ std::io::Error::last_os_error()
1241
+ );
1242
+ break;
1243
+ } else if wait_result > 0 {
1244
+ if libc::WIFSIGNALED(status) || libc::WIFSTOPPED(status) || libc::WIFEXITED(status)
1245
+ {
1246
+ break;
1247
+ }
1248
+ } else {
1249
+ unsafe {
1250
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
1251
+ debug!(
1252
+ "(PID={}) Waiting for child {} to exit. Status: {}. Wait Result: {}",
1253
+ libc::getpid(),
1254
+ child_process.pid,
1255
+ status,
1256
+ wait_result
1257
+ );
1258
+ sleep(shutdown_timeout.as_secs() as u32 + 1);
1259
+ }
1260
+ };
1261
+ }
1262
+ }
1263
+ }
1264
+ }
1265
+
1266
+ impl RackServerConfig {
1267
+ fn from_args(args: &[Value]) -> Result<Self, Error> {
1268
+ let scan_args: Args<(), (), (), (), RHash, ()> = scan_args(args)?;
1269
+
1270
+ let args: KwArgs<
1271
+ (Value,),
1272
+ (
1273
+ Option<String>,
1274
+ Option<u16>,
1275
+ Option<u16>,
1276
+ Option<u16>,
1277
+ Option<u16>,
1278
+ Option<String>,
1279
+ Option<String>,
1280
+ Option<f64>,
1281
+ Option<String>,
1282
+ ),
1283
+ (),
1284
+ > = get_kwargs(
1285
+ filter_keywords(
1286
+ scan_args.keywords,
1287
+ &[
1288
+ "app",
1289
+ "host",
1290
+ "port",
1291
+ "http_port",
1292
+ "workers",
1293
+ "threads",
1294
+ "cert_path",
1295
+ "key_path",
1296
+ "shutdown_timeout",
1297
+ "script_name",
1298
+ ],
1299
+ ),
1300
+ &["app"],
1301
+ &[
1302
+ "host",
1303
+ "port",
1304
+ "http_port",
1305
+ "workers",
1306
+ "threads",
1307
+ "cert_path",
1308
+ "key_path",
1309
+ "shutdown_timeout",
1310
+ "script_name",
1311
+ ],
1312
+ )?;
1313
+
1314
+ let args2: KwArgs<(), (Option<Proc>, Option<Proc>), ()> = get_kwargs(
1315
+ filter_keywords(scan_args.keywords, &["before_fork", "after_fork"]),
1316
+ &[],
1317
+ &["before_fork", "after_fork"],
1318
+ )?;
1319
+
1320
+ let ruby = Ruby::get().unwrap();
1321
+ let host = args.optional.0.unwrap_or("127.0.0.1".to_owned());
1322
+ let port = args.optional.1.unwrap_or(3000);
1323
+ let http_port = args.optional.2;
1324
+ let workers = args.optional.3.unwrap_or(1);
1325
+ let threads = args.optional.4.unwrap_or(1);
1326
+ let cert_path = args.optional.5;
1327
+ let key_path = args.optional.6;
1328
+ let shutdown_timeout =
1329
+ Duration::from_millis((args.optional.7.unwrap_or(5.0) * 1000.0) as u64);
1330
+ let script_name = args.optional.8.unwrap_or("/".to_owned());
1331
+
1332
+ let app = args.required.0;
1333
+ let handler: Value = app.funcall("method", (*ID_CALL,)).unwrap();
1334
+ ruby.get_inner(&OSPREY)
1335
+ .ivar_set("@handler", handler)
1336
+ .unwrap();
1337
+ ruby.get_inner(&WORK_QUEUE)
1338
+ .funcall::<&str, (), Value>("clear", ())
1339
+ .ok();
1340
+ Ok(Self {
1341
+ host,
1342
+ port,
1343
+ http_port,
1344
+ workers,
1345
+ threads,
1346
+ cert_path,
1347
+ key_path,
1348
+ shutdown_timeout,
1349
+ script_name,
1350
+ before_fork: args2.optional.0,
1351
+ after_fork: args2.optional.1,
1352
+ })
1353
+ }
1354
+ }
1355
+ /// Standard Magnus init
1356
+ #[magnus::init]
1357
+ fn init(ruby: &magnus::Ruby) -> Result<(), Error> {
1358
+ SimpleLogger::new().env().init().unwrap();
1359
+
1360
+ ruby.get_inner(&OSPREY)
1361
+ .define_singleton_method("start", function!(start_server, -1))?;
1362
+
1363
+ define_collector(ruby)?;
1364
+ define_request_info(&ruby)?;
1365
+
1366
+ Ok(())
1367
+ }
1368
+
1369
+ // Mapping functions for requests and responses
1370
+ #[inline]
1371
+ async fn map_hyper_request(
1372
+ parts: hyper::http::request::Parts,
1373
+ peer_addr: SocketAddr,
1374
+ script_name: String,
1375
+ stream_fd: i32,
1376
+ ) -> (RackRequest, oneshot::Receiver<RackResponse>) {
1377
+ let (sender, receiver) = oneshot::channel();
1378
+
1379
+ let headers: HashMap<String, String> = parts
1380
+ .headers
1381
+ .iter()
1382
+ .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
1383
+ .collect();
1384
+
1385
+ let protocol = headers
1386
+ .get("upgrade")
1387
+ .and_then(|value| {
1388
+ Some(
1389
+ value
1390
+ .split(',')
1391
+ .map(|s| s.trim().to_string())
1392
+ .collect::<Vec<_>>(),
1393
+ )
1394
+ })
1395
+ .or_else(|| {
1396
+ headers.get("protocol").map(|value| {
1397
+ value
1398
+ .split(',')
1399
+ .map(|s| s.trim().to_string())
1400
+ .collect::<Vec<_>>()
1401
+ })
1402
+ })
1403
+ .unwrap_or_else(|| vec!["http".to_string()]);
1404
+
1405
+ unsafe {
1406
+ debug!("(PID={}) peer_adr: {:?}", libc::getpid(), peer_addr);
1407
+ debug!(
1408
+ "(PID={}) host header: {:?}",
1409
+ libc::getpid(),
1410
+ headers.get("host").unwrap_or(&"".to_string())
1411
+ );
1412
+ };
1413
+ let scheme = parts.uri.scheme_str().unwrap_or("http").to_string();
1414
+ let host = parts
1415
+ .uri
1416
+ .host()
1417
+ .unwrap_or(headers.get("host").unwrap_or(&"".to_string()))
1418
+ .to_string();
1419
+ let port = host
1420
+ .split(':')
1421
+ .nth(1)
1422
+ .unwrap_or(if &scheme == "https" { "443" } else { "80" })
1423
+ .parse::<i32>()
1424
+ .unwrap();
1425
+
1426
+ let query_string = parts.uri.query().unwrap_or("").to_string();
1427
+ (
1428
+ RackRequest(RefCell::new(Some(RackRequestInner {
1429
+ path: parts.uri.path().to_string(),
1430
+ method: parts.method.to_string(),
1431
+ version: format!("{:?}", parts.version),
1432
+ remote_addr: peer_addr.ip().to_string(),
1433
+ script_name,
1434
+ scheme,
1435
+ host,
1436
+ port,
1437
+ protocol,
1438
+ query_string,
1439
+ headers,
1440
+ stream_fd,
1441
+ sender: Arc::new(Mutex::new(Some(sender))),
1442
+ }))),
1443
+ receiver,
1444
+ )
1445
+ }
1446
+
1447
+ use hyper::body::{Frame, SizeHint};
1448
+ use std::task::{Context, Poll};
1449
+ use tokio::{
1450
+ io::{AsyncBufReadExt, AsyncReadExt, BufReader},
1451
+ sync::mpsc,
1452
+ };
1453
+
1454
+ #[derive(Debug)]
1455
+ pub struct HijackedStreamingBody {
1456
+ rx: mpsc::Receiver<Result<Bytes, io::Error>>,
1457
+ done: bool,
1458
+ }
1459
+
1460
+ impl hyper::body::Body for HijackedStreamingBody {
1461
+ type Data = Bytes;
1462
+ type Error = Infallible;
1463
+
1464
+ fn poll_frame(
1465
+ mut self: Pin<&mut Self>,
1466
+ cx: &mut Context<'_>,
1467
+ ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1468
+ if self.done {
1469
+ return Poll::Ready(None);
1470
+ }
1471
+
1472
+ match Pin::new(&mut self.rx).poll_recv(cx) {
1473
+ Poll::Ready(Some(Ok(bytes))) => {
1474
+ let frame = Frame::data(bytes);
1475
+ Poll::Ready(Some(Ok(frame)))
1476
+ }
1477
+ Poll::Ready(Some(Err(e))) => {
1478
+ self.done = true;
1479
+ Poll::Ready(None)
1480
+ }
1481
+ Poll::Ready(None) => Poll::Ready(None),
1482
+ Poll::Pending => Poll::Pending,
1483
+ }
1484
+ }
1485
+ fn is_end_stream(&self) -> bool {
1486
+ false // rely on the channel closing
1487
+ }
1488
+
1489
+ fn size_hint(&self) -> SizeHint {
1490
+ SizeHint::default()
1491
+ }
1492
+ }
1493
+
1494
+ // ─── A HELPER TO PARSE THE HTTP/1.1 RESPONSE ───────────────────────────────
1495
+
1496
+ // Given a buffered reader reading from the hijacked connection, read until the header section
1497
+ // is complete (i.e. until "\r\n\r\n" is encountered), and then return the header lines and any
1498
+ // bytes that already belong to the body.
1499
+ async fn read_response_headers<R>(reader: &mut R) -> io::Result<(Vec<String>, Vec<u8>)>
1500
+ where
1501
+ R: AsyncBufReadExt + Unpin,
1502
+ {
1503
+ let mut headers_buf = Vec::new();
1504
+ // Read until we have the header terminator.
1505
+ loop {
1506
+ let mut line = String::new();
1507
+ let n = reader.read_line(&mut line).await?;
1508
+ if n == 0 {
1509
+ // EOF reached unexpectedly.
1510
+ break;
1511
+ }
1512
+ headers_buf.extend_from_slice(line.as_bytes());
1513
+ if headers_buf.ends_with(b"\r\n\r\n") || headers_buf.ends_with(b"\n\n") {
1514
+ break;
1515
+ }
1516
+ }
1517
+ // Split the headers from any already-read part of the body.
1518
+ let header_body_split = b"\r\n\r\n";
1519
+ if let Some(pos) = headers_buf
1520
+ .windows(header_body_split.len())
1521
+ .position(|window| window == header_body_split)
1522
+ {
1523
+ let header_lines = &headers_buf[..pos];
1524
+ let remaining = headers_buf[(pos + header_body_split.len())..].to_vec();
1525
+ // Split the header lines into strings.
1526
+ let headers = String::from_utf8_lossy(header_lines)
1527
+ .lines()
1528
+ .map(|s| s.to_string())
1529
+ .collect::<Vec<_>>();
1530
+ Ok((headers, remaining))
1531
+ } else {
1532
+ // If not found, return all as headers and no body data.
1533
+ let headers = String::from_utf8_lossy(&headers_buf)
1534
+ .lines()
1535
+ .map(|s| s.to_string())
1536
+ .collect::<Vec<_>>();
1537
+ Ok((headers, Vec::new()))
1538
+ }
1539
+ }
1540
+
1541
+ /// Given the raw header lines (from an HTTP/1.1 response), parse out the status and headers.
1542
+ ///
1543
+ /// This is a very simplistic parser; you may wish to use a real HTTP parser.
1544
+ fn parse_status_and_headers(lines: &[String]) -> (StatusCode, HeaderMap) {
1545
+ let mut status = StatusCode::OK;
1546
+ let mut header_map = HeaderMap::new();
1547
+
1548
+ if let Some(first_line) = lines.first() {
1549
+ // Expect a line like "HTTP/1.1 200 OK"
1550
+ let parts: Vec<&str> = first_line.split_whitespace().collect();
1551
+ if parts.len() >= 2 {
1552
+ if let Ok(code) = parts[1].parse::<u16>() {
1553
+ status = StatusCode::from_u16(code).unwrap_or(StatusCode::OK);
1554
+ }
1555
+ }
1556
+ }
1557
+ for line in &lines[1..] {
1558
+ if let Some((name, value)) = line.split_once(':') {
1559
+ header_map.insert(
1560
+ HeaderName::from_bytes(name.as_bytes()).unwrap(),
1561
+ HeaderValue::from_bytes(value.as_bytes()).unwrap(),
1562
+ );
1563
+ }
1564
+ }
1565
+ (status, header_map)
1566
+ }
1567
+ // ─── THE HIJACK HANDLING FUNCTION ───────────────────────────────────────────
1568
+
1569
+ async fn read_hijacked_headers(fd: i32) -> (HeaderMap, StatusCode, bool, Vec<u8>) {
1570
+ let hijack_stream =
1571
+ tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(fd) }).unwrap();
1572
+ let mut reader = BufReader::new(hijack_stream);
1573
+
1574
+ // Read until blank line => that's the "HTTP/1.1 101\r\nUpg...\r\n\r\n"
1575
+ let (header_lines, initial_body_bytes) = read_response_headers(&mut reader).await.unwrap();
1576
+
1577
+ // Parse them
1578
+ let (status, hdrs) = parse_status_and_headers(&header_lines);
1579
+ let requires_upgrade = status == StatusCode::SWITCHING_PROTOCOLS;
1580
+
1581
+ // Dont close the FD, we're going to use it later.
1582
+ let _ = reader.into_inner().into_std().unwrap().into_raw_fd();
1583
+
1584
+ (hdrs, status, requires_upgrade, initial_body_bytes)
1585
+ }
1586
+
1587
+ /// This function is called when rack_response.status == -1 (i.e. the app hijacked the stream).
1588
+ /// It takes the raw UnixStream (from the proxy side) and uses it to parse an HTTP/1.1 response
1589
+ /// written by the app. It then returns a Hyper Response that uses a streaming body.
1590
+ /// This function is called when `rack_response.status == -1`.
1591
+ /// We read the lines Ruby wrote over that hijacked FD, parse them,
1592
+ /// and build a streaming response. But if Ruby’s lines say `101`,
1593
+ /// we do not close the connection until Ruby closes its socket.
1594
+ async fn handle_hijacked_response(
1595
+ hijack_stream_fd: i32,
1596
+ parts: hyper::http::request::Parts,
1597
+ ) -> Result<Response<BoxBody<Bytes, Infallible>>, hyper::Error> {
1598
+ // We read the headers from the raw stream (in HTTP/1.1 format)
1599
+ // and will either upgrade the connection or use the body as a streaming body.
1600
+ let (headers, status, requires_upgrade, initial_body_bytes) =
1601
+ read_hijacked_headers(hijack_stream_fd).await;
1602
+
1603
+ let hijack_stream =
1604
+ tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(hijack_stream_fd) })
1605
+ .unwrap();
1606
+
1607
+ let mut response = if requires_upgrade {
1608
+ tokio::spawn(async move {
1609
+ let mut req = Request::from_parts(parts, Empty::<Bytes>::new());
1610
+ match hyper::upgrade::on(&mut req).await {
1611
+ Ok(upgraded) => {
1612
+ two_way_bridge(upgraded, hijack_stream)
1613
+ .await
1614
+ .expect("Error in creating two way bridge");
1615
+ }
1616
+ Err(e) => eprintln!("upgrade error: {:?}", e),
1617
+ }
1618
+ });
1619
+ Response::new(BoxBody::new(Empty::new()))
1620
+ } else {
1621
+ let (tx, rx) = mpsc::channel::<Result<Bytes, io::Error>>(16);
1622
+ if !initial_body_bytes.is_empty() {
1623
+ let _ = tx.send(Ok(Bytes::from(initial_body_bytes))).await;
1624
+ }
1625
+
1626
+ tokio::spawn(async move {
1627
+ let mut reader = BufReader::new(hijack_stream);
1628
+ let mut buf = [0u8; 1024];
1629
+ loop {
1630
+ match reader.read(&mut buf).await {
1631
+ Ok(0) => {
1632
+ info!("Hijacked socket closed. EOF");
1633
+ break;
1634
+ }
1635
+ Ok(n) => {
1636
+ // Forward chunk
1637
+ if tx
1638
+ .send(Ok(Bytes::copy_from_slice(&buf[..n])))
1639
+ .await
1640
+ .is_err()
1641
+ {
1642
+ info!("Client side dropped the channel");
1643
+ break;
1644
+ }
1645
+ }
1646
+ Err(e) => {
1647
+ info!("Read error: {:?}", e);
1648
+ let _ = tx.send(Err(e)).await;
1649
+ break;
1650
+ }
1651
+ }
1652
+ }
1653
+ });
1654
+ Response::new(BoxBody::new(HijackedStreamingBody { rx, done: false }))
1655
+ };
1656
+ *response.status_mut() = status;
1657
+ *response.headers_mut() = headers;
1658
+ Ok(response)
1659
+ }