osprey 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1659 @@
1
+ #![deny(unused_crate_dependencies)]
2
+
3
+ use bytes::Bytes;
4
+ use crossbeam::channel::{Receiver, Sender};
5
+ use http_body_util::combinators::BoxBody;
6
+ use http_body_util::{BodyExt, Empty};
7
+ use hyper::body::{Body, Incoming};
8
+ use hyper::header::{HeaderName, HeaderValue};
9
+ use hyper::service::service_fn;
10
+ use hyper::upgrade::Upgraded;
11
+ use hyper::{HeaderMap, Request, Response, StatusCode};
12
+ use hyper_util::rt::{TokioExecutor, TokioIo};
13
+ use hyper_util::server::conn::auto::Builder;
14
+ use libc::{sighandler_t, sleep};
15
+ use log::{debug, error, info, warn};
16
+ use magnus::value::{Id, Lazy, LazyId, ReprValue};
17
+ use magnus::{
18
+ block::Proc,
19
+ exception, function,
20
+ scan_args::{get_kwargs, scan_args, Args, KwArgs},
21
+ Error, Object, RHash, Value,
22
+ };
23
+ use magnus::{eval, Class, Module, RModule, Ruby};
24
+ use magnus::{RArray, RClass};
25
+ use nix::cmsg_space;
26
+ use nix::sys::socket::{
27
+ recvmsg, sendmsg, socketpair, AddressFamily, ControlMessage, ControlMessageOwned, MsgFlags,
28
+ SockFlag, SockType,
29
+ };
30
+ use std::cell::RefCell;
31
+ use std::convert::Infallible;
32
+ use std::os::unix::io::RawFd;
33
+ use std::os::unix::net::UnixStream;
34
+ use std::ptr::null_mut;
35
+ use structs::rack_request::{define_request_info, RackRequest, RackRequestInner};
36
+ use tokio::io::unix::AsyncFd;
37
+ use tokio::io::{AsyncRead, AsyncWrite};
38
+ use tokio::sync::broadcast::channel;
39
+ use tokio_stream::StreamExt;
40
+
41
+ use rb_sys::{rb_thread_call_with_gvl, rb_thread_call_without_gvl, rb_thread_create};
42
+ use simple_logger::SimpleLogger;
43
+ use socket2::{Domain, Protocol, Socket, Type};
44
+ use std::io::{IoSlice, IoSliceMut};
45
+ use std::net::{TcpListener, TcpStream};
46
+ use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd};
47
+ use std::panic;
48
+ use std::pin::Pin;
49
+ use std::sync::atomic::Ordering;
50
+ use std::sync::atomic::{AtomicBool, AtomicUsize};
51
+ use std::sync::{Mutex, RwLock};
52
+ use std::{
53
+ collections::HashMap, ffi::c_void, fs::OpenOptions, net::SocketAddr, os::fd::AsRawFd,
54
+ sync::Arc, time::Duration,
55
+ };
56
+ use structs::rack_response::{define_collector, RackResponse};
57
+ use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream};
58
+ use tokio::runtime::Builder as RuntimeBuilder;
59
+ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
60
+ use tokio::sync::oneshot::{self};
61
+ use tokio::task::{JoinError, JoinHandle};
62
+ use tokio::time::sleep as TokioSleep;
63
+ use tokio_rustls::rustls::ServerConfig;
64
+ use tokio_rustls::TlsAcceptor;
65
+
66
+ mod structs;
67
+
68
+ trait IoStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
69
+ impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin> IoStream for T {}
70
+
71
// --------------------------------------------------
// Global/Static shared data for forking & shutdown
// --------------------------------------------------

/// Known child worker processes (pid + parent-side socket writer).
/// Read by the parent accept loop to round-robin connections and by `initiate_shutdown`
/// to signal children; presumably populated by the fork code outside this view.
static CHILDREN: RwLock<Vec<ChildProcess>> = RwLock::new(vec![]);

/// Set once shutdown has begun; polled by the accept loops and the Ruby worker threads.
static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);
/// Sender that `initiate_shutdown` takes and fires to wake the running accept loop.
/// Each accept loop installs a fresh sender when it starts.
static CTRL_C_SIGNAL_CHANNEL: RwLock<Option<tokio::sync::broadcast::Sender<()>>> =
    RwLock::new(None);
80
+
81
/// The `Osprey` Ruby module, defined (or fetched) on first access.
pub static OSPREY: Lazy<RModule> = Lazy::new(|ruby| ruby.define_module("Osprey").unwrap());
/// Interned Ruby method names, resolved lazily on first use.
pub static ID_SCHEDULE: LazyId = LazyId::new("schedule");
pub static ID_CALL: LazyId = LazyId::new("call");
pub static ID_NEW: LazyId = LazyId::new("new");
pub static ID_PUSH: LazyId = LazyId::new("push");
pub static ID_RESUME: LazyId = LazyId::new("resume");
pub static ID_TRANSFORM_VALUES: LazyId = LazyId::new("transform_values");
/// `Kernel.method("Array").to_proc` — a Proc wrapping Ruby's `Kernel#Array`.
pub static ARRAY_PROC: Lazy<Proc> = Lazy::new(|ruby| {
    ruby.module_kernel()
        .funcall::<&str, (&str,), Value>("method", ("Array",))
        .unwrap()
        .funcall("to_proc", ())
        .unwrap()
});
/// A fresh Ruby `Queue` instance; on first access it is also stored as the
/// `@work_queue` instance variable of the `Osprey` module.
pub static WORK_QUEUE: Lazy<Value> = Lazy::new(|ruby| {
    let queue = ruby
        .define_class("Queue", ruby.class_object())
        .unwrap()
        .funcall::<&str, (), Value>("new", ())
        .unwrap();
    ruby.get_inner(&OSPREY)
        .ivar_set("@work_queue", queue)
        .unwrap();
    queue
});

/// `Async::Scheduler`. NOTE(review): presumably requires the `async` gem to be loaded
/// before first access; otherwise `const_get` fails and the `unwrap` panics.
static SCHEDULER_CLASS: Lazy<RClass> = Lazy::new(|ruby| {
    ruby.define_module("Async")
        .unwrap()
        .const_get("Scheduler")
        .unwrap()
});

/// Ruby's `Fiber` class.
static FIBER_CLASS: Lazy<RClass> =
    Lazy::new(|ruby| ruby.define_class("Fiber", ruby.class_object()).unwrap());
116
/// Server configuration: bind addresses, TLS material, concurrency, shutdown grace period,
/// and optional Ruby fork hooks.
#[derive(Clone)]
struct RackServerConfig {
    /// Interface to bind; joined with a port as `host:port`.
    host: String,
    /// Port of the HTTPS listener (only bound when both `cert_path` and `key_path` are set).
    port: u16,
    /// Port of the plain-HTTP listener; no HTTP listener is created when `None`.
    http_port: Option<u16>,
    /// Number of worker processes — presumably forked by code outside this view.
    workers: u16,
    /// Thread count — presumably copied into `AcceptLoopArgs::threads`.
    threads: u16,
    /// TLS certificate path; HTTPS is enabled only when this and `key_path` are both set.
    cert_path: Option<String>,
    /// TLS private key path.
    key_path: Option<String>,
    /// Grace period for in-flight requests during shutdown.
    shutdown_timeout: Duration,
    /// Forwarded to each request — presumably the Rack `SCRIPT_NAME`.
    script_name: String,
    /// Optional Ruby hook run before forking workers.
    before_fork: Option<Proc>,
    /// Optional Ruby hook — presumably run in each worker after forking.
    after_fork: Option<Proc>,
}
131
+
132
/// A forked worker as seen from the parent process.
#[derive(Debug, Clone)]
struct ChildProcess {
    /// Process id; signalled by `initiate_shutdown`.
    pid: i32,
    /// Parent's end of the socket pair to this child; accepted connections are passed over it
    /// with `sendmsg` + `SCM_RIGHTS`.
    writer: Arc<OwnedFd>,
}
137
+
138
/// Arguments handed (as a raw pointer) to the parent `accept_loop` and to the worker
/// `accept_ruby_connection_loop`.
struct AcceptLoopArgs {
    /// Std listener for plain HTTP; accepted on by the parent, dropped by workers.
    http_listener: Option<TcpListener>,
    /// Std listener for HTTPS; accepted on by the parent, dropped by workers.
    https_listener: Option<TcpListener>,
    /// Maximum time a worker waits for in-flight connection tasks on shutdown.
    shutdown_timeout: Duration,
    /// Tokio worker threads in the parent; Ruby thread-pool size in a worker.
    threads: u16,
    /// Worker's end of the socket pair from the parent (connections arrive here).
    /// Must be `Some` in a worker (it is unwrapped).
    reader: Option<OwnedFd>,
    /// Forwarded to each request.
    script_name: String,
    /// TLS certificate path used to build the worker's TLS acceptor.
    cert_path: Option<String>,
    /// TLS private key path used to build the worker's TLS acceptor.
    key_path: Option<String>,
}
148
+
149
+ unsafe extern "C" fn accept_loop_interrupt(_ptr: *mut c_void) {
150
+ info!("(PID={}) Interrupt accept loop", libc::getpid());
151
+ }
152
+
153
+ unsafe fn setup_listeners(
154
+ server: &RackServerConfig,
155
+ ) -> Result<(Option<TcpListener>, Option<TcpListener>), Error> {
156
+ let mut http_listener = None;
157
+ let mut https_listener = None;
158
+
159
+ // Set up HTTP listener if http_port is provided
160
+ if let Some(http_port) = server.http_port {
161
+ let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
162
+ Error::new(
163
+ exception::io_error(),
164
+ format!("Socket creation failed: {}", e),
165
+ )
166
+ })?;
167
+ socket.set_reuse_address(true).ok();
168
+ socket.set_reuse_port(true).ok();
169
+ socket.set_nonblocking(true).ok();
170
+ socket.set_nodelay(true).ok();
171
+ socket.set_recv_buffer_size(1048576).ok();
172
+
173
+ let address: SocketAddr =
174
+ format!("{}:{}", server.host, http_port)
175
+ .parse()
176
+ .map_err(|e| {
177
+ Error::new(
178
+ exception::io_error(),
179
+ format!(
180
+ "Failed to parse address {}: {}",
181
+ format!("{}:{}", server.host, http_port),
182
+ e
183
+ ),
184
+ )
185
+ })?;
186
+ socket.bind(&address.into()).map_err(|e| {
187
+ Error::new(
188
+ exception::io_error(),
189
+ format!("bind failed for HTTP: {}", e),
190
+ )
191
+ })?;
192
+ socket.listen(1024).map_err(|e| {
193
+ Error::new(
194
+ exception::io_error(),
195
+ format!("listen failed for HTTP: {}", e),
196
+ )
197
+ })?;
198
+ let listener: TcpListener = socket.into();
199
+ listener.set_nonblocking(true).map_err(|e| {
200
+ Error::new(
201
+ exception::io_error(),
202
+ format!("set_nonblocking failed for HTTP: {}", e),
203
+ )
204
+ })?;
205
+ info!("(PID={}) Listening on http://{}", libc::getpid(), address);
206
+ http_listener = Some(listener);
207
+ }
208
+
209
+ // Set up HTTPS listener if cert_path and key_path are provided
210
+ if server.cert_path.is_some() && server.key_path.is_some() {
211
+ let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).map_err(|e| {
212
+ Error::new(
213
+ exception::io_error(),
214
+ format!("Socket creation failed: {}", e),
215
+ )
216
+ })?;
217
+ socket.set_reuse_address(true).ok();
218
+ socket.set_reuse_port(true).ok();
219
+ socket.set_nonblocking(true).ok();
220
+ socket.set_nodelay(true).ok();
221
+ socket.set_recv_buffer_size(1048576).ok();
222
+
223
+ let address: SocketAddr =
224
+ format!("{}:{}", server.host, server.port)
225
+ .parse()
226
+ .map_err(|e| {
227
+ Error::new(
228
+ exception::io_error(),
229
+ format!(
230
+ "Failed to parse address {}: {}",
231
+ format!("{}:{}", server.host, server.port),
232
+ e
233
+ ),
234
+ )
235
+ })?;
236
+ socket.bind(&address.into()).map_err(|e| {
237
+ Error::new(
238
+ exception::io_error(),
239
+ format!("bind failed for HTTPS: {}", e),
240
+ )
241
+ })?;
242
+ socket.listen(1024).map_err(|e| {
243
+ Error::new(
244
+ exception::io_error(),
245
+ format!("listen failed for HTTPS: {}", e),
246
+ )
247
+ })?;
248
+ let listener: TcpListener = socket.into();
249
+ listener.set_nonblocking(true).map_err(|e| {
250
+ Error::new(
251
+ exception::io_error(),
252
+ format!("set_nonblocking failed for HTTP: {}", e),
253
+ )
254
+ })?;
255
+ info!("(PID={}) Listening on https://{}", libc::getpid(), address);
256
+ https_listener = Some(listener);
257
+ }
258
+
259
+ Ok((http_listener, https_listener))
260
+ }
261
+
262
+ async fn await_fd(fd: RawFd) -> Result<(), JoinError> {
263
+ let async_fd_result = AsyncFd::new(fd);
264
+
265
+ if let Ok(async_fd) = async_fd_result {
266
+ return tokio::spawn(async move {
267
+ loop {
268
+ match async_fd.readable().await {
269
+ Ok(guard) => {
270
+ if guard.ready().is_read_closed() {
271
+ break;
272
+ } else {
273
+ let _ = TokioSleep(Duration::from_millis(1000)).await;
274
+ }
275
+ }
276
+ Err(e) => {
277
+ println!("Error waiting for readiness: {:?}", e);
278
+ break;
279
+ }
280
+ }
281
+ }
282
+ })
283
+ .await;
284
+ }
285
+ Ok(())
286
+ }
287
+
288
+ async unsafe fn accept_loop_async(serve_args: &mut AcceptLoopArgs) -> Result<(), Error> {
289
+ let tls_listener = serve_args.https_listener.as_ref().map(|listener| {
290
+ TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
291
+ .expect("Construct Tokio Listener from STD")
292
+ });
293
+ let non_tls_listener = serve_args.http_listener.as_ref().map(|listener| {
294
+ TokioTcpListener::from_std(listener.try_clone().expect("Clone STD Listener"))
295
+ .expect("Construct Tokio Listener from STD")
296
+ });
297
+
298
+ let (signal_tx, mut signal_rx) = channel(1);
299
+ let (shutdown_tx, mut shutdown_rx) = channel(1);
300
+ *CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);
301
+ let children = CHILDREN.read().unwrap().to_vec();
302
+ let mut child_cycle = children.iter().cycle();
303
+
304
+ tokio::spawn(async move {
305
+ signal_rx.recv().await.ok();
306
+ info!(
307
+ "(PID={}) Signal received. Initiating shutdown",
308
+ libc::getpid()
309
+ );
310
+ shutdown_tx.send(()).ok();
311
+ });
312
+
313
+ loop {
314
+ tokio::select! {
315
+ conn = async { if let Some(listener) = tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if tls_listener.is_some() => {
316
+ if let Some(Ok((stream, _peer_addr))) = conn {
317
+ let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
318
+
319
+ let iov = [IoSlice::new(b"1")];
320
+ let next_child = child_cycle.next().expect("Next child process");
321
+ sendmsg::<()>(
322
+ next_child.writer.as_raw_fd() as i32,
323
+ &iov,
324
+ &[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
325
+ MsgFlags::empty(),
326
+ None, // No address
327
+ ).expect("Failed to send file descriptor");
328
+
329
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
330
+ break;
331
+ }
332
+ }
333
+ },
334
+ conn = async { if let Some(listener) = non_tls_listener.as_ref() { Some(listener.accept().await) } else { None } }, if non_tls_listener.is_some() => {
335
+ if let Some(Ok((stream, _peer_addr))) = conn {
336
+ let fd = stream.into_std().expect("Failed to convert to std FD").into_raw_fd();
337
+ let iov = [IoSlice::new(b"0")];
338
+ let next_child = child_cycle.next().expect("Next child process");
339
+ sendmsg::<()>(
340
+ next_child.writer.as_raw_fd() as i32,
341
+ &iov,
342
+ &[ControlMessage::ScmRights(&[fd])], // Send the file descriptor
343
+ MsgFlags::empty(),
344
+ None, // No address
345
+ ).expect("Failed to send file descriptor");
346
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
347
+ break;
348
+ }
349
+ }
350
+ }
351
+ _ = shutdown_rx.recv() => {
352
+ info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
353
+ break;
354
+ }
355
+ }
356
+ }
357
+
358
+ info!(
359
+ "(PID={}) Waiting up to {:?} for in-flight requests to complete",
360
+ libc::getpid(),
361
+ serve_args.shutdown_timeout
362
+ );
363
+
364
+ debug!("(PID={}) Dropping listeners", libc::getpid());
365
+ drop(tls_listener);
366
+ drop(non_tls_listener);
367
+ initiate_shutdown(0, 0);
368
+
369
+ Ok(())
370
+ }
371
+
372
+ use std::io;
373
+ use tokio::{io::AsyncWriteExt, net::UnixStream as TokioUnixStream};
374
+
375
+ async fn bridge_with_logging(
376
+ mut from: (impl AsyncRead + Unpin),
377
+ mut to: (impl AsyncWrite + Unpin),
378
+ direction_label: &str,
379
+ ) -> io::Result<u64> {
380
+ use tokio::io::{AsyncReadExt, AsyncWriteExt};
381
+ let mut buf = [0u8; 4096];
382
+ let mut total_bytes = 0u64;
383
+
384
+ loop {
385
+ let n = from.read(&mut buf).await?;
386
+ if n == 0 {
387
+ // EOF => no more data
388
+ break;
389
+ }
390
+ println!(
391
+ "[bridge {}] Read {} bytes: {:?}",
392
+ direction_label,
393
+ n,
394
+ &buf[..n]
395
+ );
396
+
397
+ to.write_all(&buf[..n]).await?;
398
+ total_bytes += n as u64;
399
+ }
400
+ Ok(total_bytes)
401
+ }
402
+
403
+ async fn two_way_bridge(upgraded: Upgraded, local: TokioUnixStream) -> io::Result<()> {
404
+ // Wrap the `Upgraded` with `TokioIo` so it implements AsyncRead+AsyncWrite
405
+ let client_io = TokioIo::new(upgraded);
406
+
407
+ // Split each side
408
+ let (mut lr, mut lw) = tokio::io::split(local);
409
+ let (mut cr, mut cw) = tokio::io::split(client_io);
410
+
411
+ let to_ruby = tokio::spawn(async move {
412
+ if let Err(e) = tokio::io::copy(&mut cr, &mut lw).await {
413
+ eprintln!("Error copying upgraded->local: {:?}", e);
414
+ }
415
+ });
416
+ let from_ruby = tokio::spawn(async move {
417
+ if let Err(e) = tokio::io::copy(&mut lr, &mut cw).await {
418
+ eprintln!("Error copying upgraded->local: {:?}", e);
419
+ }
420
+ });
421
+
422
+ let _ = to_ruby.await;
423
+ let _ = from_ruby.await;
424
+ Ok(())
425
+ }
426
+
427
+ async unsafe fn handle_hyper_request(
428
+ req: Request<Incoming>,
429
+ peer_addr: SocketAddr,
430
+ sender: Sender<RackRequest>,
431
+ script_name: String,
432
+ ) -> Result<Response<BoxBody<Bytes, Infallible>>, &'static str> {
433
+ let (parts, body) = req.into_parts();
434
+
435
+ let has_body = !body.is_end_stream();
436
+
437
+ // 1. (Optional) If there's a body, we'll expose it as an IO into our Ruby process.
438
+ let (local_fd, remote_fd) = if has_body {
439
+ let (local, remote) = TokioUnixStream::pair().map_err(|_| "socketpair failed")?;
440
+ let local_fd = local.as_raw_fd();
441
+ let (_lr, mut lw) = tokio::io::split(local);
442
+ let mut data_stream = body.into_data_stream();
443
+
444
+ // Stream all incoming data into our remote FD (Opened inside Ruby as an IO)
445
+ tokio::spawn(async move {
446
+ while let Some(Ok(chunk)) = data_stream.next().await {
447
+ info!("Got write chunk ! {:?}", chunk);
448
+ if lw.write_all(&chunk).await.is_err() {
449
+ break;
450
+ }
451
+ }
452
+ info!("Relinquishing ownership of local_fd");
453
+ _lr.unsplit(lw).into_std().unwrap().into_raw_fd()
454
+ });
455
+
456
+ let remote_fd = remote.into_std().unwrap().into_raw_fd();
457
+ (Some(local_fd), remote_fd)
458
+ } else {
459
+ (None, -1)
460
+ };
461
+
462
+ // 2) Build and send a RackRequest to our Rack app.
463
+ let (rack_req, resp_rx) =
464
+ map_hyper_request(parts.clone(), peer_addr, script_name, remote_fd).await;
465
+ sender.send(rack_req).expect("Failed to send to Ruby");
466
+
467
+ // 3) Then wait for our Rack app to return a RackResponse
468
+ let rack_response = resp_rx.await.map_err(|_| "Ruby channel closed")?;
469
+
470
+ // 4) We return a negative status if the response was hijacked. We'll reuse the body pipe
471
+ // if it already exists, otherwise, we can count on the hijacked response to have created a new one.
472
+ if rack_response.status < 0 {
473
+ let local_fd = match local_fd {
474
+ Some(ld) => ld,
475
+ None => (rack_response.status * -1) as i32,
476
+ };
477
+
478
+ return match handle_hijacked_response(local_fd, parts).await {
479
+ Ok(resp) => Ok(resp),
480
+ Err(_) => Err("Error handling hijacked response"),
481
+ };
482
+ }
483
+
484
+ // 5) Otherwise, normal HTTP response
485
+ Ok(rack_response.into_hyper_response())
486
+ }
487
+
488
+ async fn handle_connection(
489
+ input_stream: Pin<Box<tokio::net::TcpStream>>,
490
+ tls_acceptor: Option<TlsAcceptor>,
491
+ server: Arc<Builder<TokioExecutor>>,
492
+ peer_addr: SocketAddr,
493
+ sender: Sender<RackRequest>,
494
+ script_name: String,
495
+ ) {
496
+ unsafe {
497
+ debug!("(PID={}) Handling connection", libc::getpid());
498
+ }
499
+ let stream: TokioIo<_> = match &tls_acceptor {
500
+ Some(acceptor) => match acceptor.accept(input_stream).await {
501
+ Ok(tls_stream) => TokioIo::new(Box::pin(tls_stream) as Pin<Box<dyn IoStream>>),
502
+ Err(err) => {
503
+ warn!("TLS handshake error: {}", err);
504
+ return;
505
+ }
506
+ },
507
+ None => TokioIo::new(Box::pin(input_stream) as Pin<Box<dyn IoStream>>),
508
+ };
509
+
510
+ unsafe {
511
+ debug!("(PID={}) serve connection with upgrades", libc::getpid());
512
+ }
513
+
514
+ if let Err(err) = server
515
+ .serve_connection_with_upgrades(
516
+ stream,
517
+ service_fn(move |hyper_request| unsafe {
518
+ handle_hyper_request(
519
+ hyper_request,
520
+ peer_addr,
521
+ sender.clone(),
522
+ script_name.clone(),
523
+ )
524
+ }),
525
+ )
526
+ .await
527
+ {
528
+ debug!("Connection aborted: {:?}", err);
529
+ }
530
+
531
+ unsafe { debug!("(PID={}) Connection closed", libc::getpid()) };
532
+ }
533
+
534
+ async unsafe fn accept_ruby_connection_async(
535
+ serve_args: &mut AcceptLoopArgs,
536
+ sender: Sender<RackRequest>,
537
+ ) {
538
+ let reader = serve_args.reader.as_mut().unwrap();
539
+
540
+ drop(serve_args.http_listener.take());
541
+ drop(serve_args.https_listener.take());
542
+
543
+ let buf_raw = &mut [0u8; 1024];
544
+ let mut buf = [IoSliceMut::new(buf_raw)];
545
+ let server = Arc::new(Builder::new(TokioExecutor::new()));
546
+ let mut cmsg_space = cmsg_space!(Vec<RawFd>);
547
+
548
+ let (signal_tx, mut signal_rx) = channel(1);
549
+ let (shutdown_tx, mut shutdown_rx) = channel(1);
550
+ *CTRL_C_SIGNAL_CHANNEL.write().unwrap() = Some(signal_tx);
551
+
552
+ let mut current_requests: (
553
+ UnboundedSender<JoinHandle<()>>,
554
+ UnboundedReceiver<JoinHandle<()>>,
555
+ ) = tokio::sync::mpsc::unbounded_channel();
556
+
557
+ let in_flight_requests_handle = tokio::spawn(async move {
558
+ loop {
559
+ tokio::select! {
560
+ _ = signal_rx.recv() => {
561
+ info!("(PID={}) Signal received. Initiating shutdown", libc::getpid());
562
+ shutdown_tx.send(()).ok();
563
+ break
564
+ }
565
+ result = current_requests.1.recv() => {
566
+ match result {
567
+ Some(handle) => {
568
+ if let Err(e) = handle.await {
569
+ error!("Error handling request: {:?}", e);
570
+ }
571
+ },
572
+ _ => {
573
+ continue;
574
+ }
575
+ }
576
+ }
577
+ }
578
+ }
579
+ info!("(PID={}) Flushing final requests", libc::getpid());
580
+
581
+ info!(
582
+ "(PID={}) Awaiting in-flight request for clean shut down",
583
+ libc::getpid()
584
+ );
585
+ while let Ok(handle) = current_requests.1.try_recv() {
586
+ if let Err(e) = handle.await {
587
+ error!("Error handling request: {:?}", e);
588
+ }
589
+ }
590
+ });
591
+ let child_socket = Arc::new(AsyncFd::new(reader.as_raw_fd()).unwrap());
592
+ let tls_config = configure_tls(serve_args.cert_path.clone(), serve_args.key_path.clone());
593
+ let tls_acceptor = tls_config.map(TlsAcceptor::from);
594
+
595
+ info!("(PID={}) Starting main loop", libc::getpid());
596
+
597
+ loop {
598
+ tokio::select! {
599
+ guard = child_socket.readable() => {
600
+ match guard {
601
+ Ok(mut ready_guard) => {
602
+ while let Ok(msg) = recvmsg::<()>(
603
+ reader.as_raw_fd(),
604
+ &mut buf,
605
+ Some(&mut cmsg_space),
606
+ MsgFlags::MSG_DONTWAIT,
607
+ ) {
608
+ if msg.bytes == 0 {
609
+ info!("(PID={}) Received EOF. Exiting", libc::getpid());
610
+ return;
611
+ }
612
+
613
+ if let Some(ControlMessageOwned::ScmRights(fds)) = msg.cmsgs().unwrap().next() {
614
+ let fd: i32 = *fds.first().expect("File descriptor");
615
+ let tls: bool = buf[0][0] == '1' as u8;
616
+
617
+ let stream = TokioTcpStream::from_std(unsafe { TcpStream::from_raw_fd(fd) })
618
+ .expect("Tokio from std");
619
+
620
+
621
+ match stream.peer_addr(){
622
+ Ok(peer_addr) => {
623
+ let server_clone = server.clone();
624
+
625
+ let script_name_clone = serve_args.script_name.clone();
626
+ let tls_acceptor = if tls { tls_acceptor.clone() } else { None };
627
+ let sender_clone = sender.clone();
628
+ let handle = tokio::spawn(async move {
629
+ handle_connection(
630
+ Box::pin(stream),
631
+ tls_acceptor,
632
+ server_clone,
633
+ peer_addr,
634
+ sender_clone,
635
+ script_name_clone
636
+ )
637
+ .await;
638
+ });
639
+
640
+ info!("(PID={}) Spawned connection handler", libc::getpid());
641
+ if let Err(e) = current_requests.0.send(handle) {
642
+ error!("Failed to enqueue request: {:?}", e);
643
+ }
644
+ }
645
+ Err(e) => {
646
+ error!("Failed to get peer address: {:?}", e);
647
+ continue;
648
+ }
649
+ }
650
+
651
+ } else {
652
+ error!("No file descriptor received");
653
+ }
654
+ }
655
+
656
+ // Clear readiness state after processing
657
+ ready_guard.clear_ready();
658
+ }
659
+ Err(e) => {
660
+ error!("(PID={}) Failed to poll readability: {:?}", libc::getpid(), e);
661
+ }
662
+ }
663
+ }
664
+ _ = shutdown_rx.recv() => {
665
+ info!("(PID={}) Shutdown signal received. Breaking out of accept loop", libc::getpid());
666
+ break;
667
+ }
668
+ }
669
+ }
670
+
671
+ tokio::select! {
672
+ _ = in_flight_requests_handle => {
673
+ debug!("(PID={}) In-flight requests handled gracefully", libc::getpid());
674
+ }
675
+ _ = tokio::time::sleep(serve_args.shutdown_timeout) => {
676
+ debug!("(PID={}) Forced shutdown timeout reached", libc::getpid());
677
+ }
678
+ }
679
+
680
+ debug!("(PID={}) Dropping listeners", libc::getpid());
681
+ initiate_shutdown(0, 0);
682
+ }
683
+
684
+ pub unsafe fn call_without_gvl<F, R>(f: F) -> R
685
+ where
686
+ F: FnOnce() -> R,
687
+ {
688
+ // This is the function that Ruby calls “in the background” with the GVL released.
689
+ extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
690
+ where
691
+ F: FnOnce() -> R,
692
+ {
693
+ // 1) Reconstruct the Box that holds our closure
694
+ let closure_ptr = arg as *mut Option<F>;
695
+ let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
696
+
697
+ // 2) Call the user’s closure
698
+ let result = closure();
699
+
700
+ // 3) Box up the result so we can return a pointer to it
701
+ let boxed_result = Box::new(result);
702
+ Box::into_raw(boxed_result) as *mut c_void
703
+ }
704
+
705
+ // Box up the closure so we have a stable pointer
706
+ let mut closure_opt = Some(f);
707
+ let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
708
+
709
+ // 4) Actually call `rb_thread_call_without_gvl`
710
+ let raw_result_ptr =
711
+ rb_thread_call_without_gvl(Some(trampoline::<F, R>), closure_ptr, None, null_mut());
712
+
713
+ // 5) Convert the returned pointer back into R
714
+ let result_box = Box::from_raw(raw_result_ptr as *mut R);
715
+ *result_box
716
+ }
717
+
718
+ pub unsafe fn call_with_gvl<F, R>(f: F) -> R
719
+ where
720
+ F: FnOnce() -> R,
721
+ {
722
+ // This is the function that Ruby calls “in the background” with the GVL released.
723
+ extern "C" fn trampoline<F, R>(arg: *mut c_void) -> *mut c_void
724
+ where
725
+ F: FnOnce() -> R,
726
+ {
727
+ // 1) Reconstruct the Box that holds our closure
728
+ let closure_ptr = arg as *mut Option<F>;
729
+ let closure = unsafe { (*closure_ptr).take().expect("Closure already taken") };
730
+
731
+ // 2) Call the user’s closure
732
+ let result = closure();
733
+
734
+ // 3) Box up the result so we can return a pointer to it
735
+ let boxed_result = Box::new(result);
736
+ Box::into_raw(boxed_result) as *mut c_void
737
+ }
738
+
739
+ // Box up the closure so we have a stable pointer
740
+ let mut closure_opt = Some(f);
741
+ let closure_ptr = &mut closure_opt as *mut Option<F> as *mut c_void;
742
+
743
+ // 4) Actually call `rb_thread_call_without_gvl`
744
+ let raw_result_ptr = rb_thread_call_with_gvl(Some(trampoline::<F, R>), closure_ptr);
745
+
746
+ // 5) Convert the returned pointer back into R
747
+ let result_box = Box::from_raw(raw_result_ptr as *mut R);
748
+ *result_box
749
+ }
750
+
751
+ unsafe extern "C" fn accept_ruby_connection_loop(ptr: *mut c_void) -> *mut c_void {
752
+ let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
753
+ let mut builder: RuntimeBuilder = RuntimeBuilder::new_current_thread();
754
+
755
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
756
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
757
+
758
+ let runtime = builder
759
+ .thread_name("osprey-accept-runtime")
760
+ .thread_stack_size(3 * 1024 * 1024)
761
+ .enable_io()
762
+ .enable_time()
763
+ .build()
764
+ .expect("Failed to build Tokio runtime");
765
+
766
+ let (sender, receiver) = crossbeam::channel::unbounded::<RackRequest>();
767
+ start_ruby_thread_pool(serve_args.threads as usize, Arc::new(receiver));
768
+
769
+ let _ = runtime.block_on(accept_ruby_connection_async(serve_args, sender));
770
+
771
+ std::ptr::null_mut()
772
+ }
773
+
774
+ unsafe extern "C" fn accept_loop(ptr: *mut c_void) -> *mut c_void {
775
+ let serve_args = &mut *(ptr as *mut AcceptLoopArgs);
776
+ let mut builder: RuntimeBuilder = if serve_args.threads == 1 {
777
+ RuntimeBuilder::new_current_thread()
778
+ } else {
779
+ RuntimeBuilder::new_multi_thread()
780
+ };
781
+ let runtime = builder
782
+ .worker_threads(serve_args.threads as usize)
783
+ .thread_name("osprey-accept-runtime")
784
+ .thread_stack_size(3 * 1024 * 1024)
785
+ .enable_io()
786
+ .enable_time()
787
+ .build()
788
+ .expect("Failed to build Tokio runtime");
789
+
790
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
791
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
792
+
793
+ let _ = runtime.block_on(accept_loop_async(serve_args));
794
+ runtime.shutdown_timeout(Duration::from_millis(2000));
795
+
796
+ info!(
797
+ "(PID={}) accept_loop finished, shutting down.",
798
+ libc::getpid()
799
+ );
800
+
801
+ libc::signal(libc::SIGTERM, libc::SIG_DFL);
802
+ libc::signal(libc::SIGINT, libc::SIG_DFL);
803
+
804
+ debug!(
805
+ "(PID={}) HTTP listener state: {:?}",
806
+ libc::getpid(),
807
+ serve_args.http_listener
808
+ );
809
+ debug!("(PID={}) Dropping STD listeners", libc::getpid());
810
+ drop(serve_args.http_listener.take());
811
+ drop(serve_args.https_listener.take());
812
+
813
+ std::ptr::null_mut()
814
+ }
815
+
816
+ unsafe extern "C" fn gather_requests_no_gvl(ptr: *mut c_void) -> *mut c_void {
817
+ let receiver = unsafe { &*(ptr as *const Arc<Receiver<RackRequest>>) };
818
+
819
+ let batch_size = 25;
820
+ let mut batch = Vec::with_capacity(batch_size);
821
+
822
+ // Try to collect a batch of up to `batch_size` requests
823
+ while batch.len() < batch_size {
824
+ match receiver.try_recv() {
825
+ Ok(request) => batch.push(request),
826
+ Err(_) => break,
827
+ }
828
+ }
829
+
830
+ // If the batch is empty, wait for a single request
831
+ if batch.is_empty() {
832
+ match receiver.recv_timeout(Duration::from_micros(100)) {
833
+ Ok(request) => batch.push(request),
834
+ Err(_) => (),
835
+ }
836
+ }
837
+
838
+ // Process the collected batch of requests
839
+ Box::into_raw(Box::new(batch)) as *mut c_void
840
+ }
841
+
842
+ unsafe fn gather_requests(receiver: &Arc<Receiver<RackRequest>>) -> Vec<RackRequest> {
843
+ let data_ptr = receiver as *const Arc<Receiver<RackRequest>> as *mut c_void;
844
+ let result_ptr =
845
+ rb_thread_call_without_gvl(Some(gather_requests_no_gvl), data_ptr, None, null_mut());
846
+ let result = Box::from_raw(result_ptr as *mut Vec<RackRequest>);
847
+ *result
848
+ }
849
+
850
+ unsafe extern "C" fn ruby_thread_worker(ptr: *mut c_void) -> u64 {
851
+ let ruby = Ruby::get_unchecked();
852
+ let receiver = *Box::from_raw(ptr as *mut Arc<Receiver<RackRequest>>);
853
+ let use_scheduler = true;
854
+
855
+ if use_scheduler {
856
+ let receiver_clone = receiver.clone();
857
+ let fiber_class = ruby.get_inner(&FIBER_CLASS);
858
+ let scheduler = ruby.get_inner(&SCHEDULER_CLASS).new_instance(()).unwrap();
859
+ fiber_class
860
+ .funcall::<_, _, Value>("set_scheduler", (scheduler,))
861
+ .unwrap();
862
+ let fib = ruby.proc_from_fn(move |_rb, _arg, _blk| {
863
+ let ruby = Ruby::get_unchecked();
864
+ let fiber_class = ruby.get_inner(&FIBER_CLASS);
865
+ let receiver = receiver_clone.clone();
866
+ let scheduler: Value = fiber_class.funcall("scheduler", ()).unwrap();
867
+ let id_yield = ruby.intern("yield");
868
+ let atomic_counter = Arc::new(AtomicUsize::new(0));
869
+ loop {
870
+ let current_count = atomic_counter.load(Ordering::Relaxed);
871
+ if current_count > 0 {
872
+ if let Err(_) = scheduler.funcall::<Id, (), Value>(id_yield, ()) {
873
+ break;
874
+ }
875
+ }
876
+ let batch = gather_requests(&receiver);
877
+ for request in batch {
878
+ let atomic_counter_clone = Arc::clone(&atomic_counter);
879
+ let fiber = ruby
880
+ .fiber_new_from_fn(Default::default(), move |ruby, _args, _block| {
881
+ atomic_counter_clone.fetch_add(1, Ordering::Relaxed);
882
+ request.process(ruby).expect("Error processing request");
883
+ atomic_counter_clone.fetch_sub(1, Ordering::Relaxed);
884
+ })
885
+ .unwrap();
886
+
887
+ scheduler.funcall::<_, _, Value>(*ID_RESUME, (fiber,)).ok();
888
+ }
889
+
890
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
891
+ break;
892
+ }
893
+ }
894
+ });
895
+ fiber_class
896
+ .funcall_with_block::<_, _, Value>(*ID_SCHEDULE, (), fib)
897
+ .expect("scheduled fiber");
898
+ } else {
899
+ loop {
900
+ let batch = gather_requests(&receiver);
901
+ for request in batch {
902
+ request.process(&ruby).expect("Error processing request");
903
+ }
904
+
905
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
906
+ break;
907
+ }
908
+ }
909
+ }
910
+ 0
911
+ }
912
+
913
+ fn start_ruby_thread_pool(pool_size: usize, receiver: Arc<Receiver<RackRequest>>) {
914
+ for _i in 0..pool_size {
915
+ unsafe {
916
+ rb_thread_create(
917
+ Some(ruby_thread_worker),
918
+ Box::into_raw(Box::new(receiver.clone())) as *mut c_void,
919
+ )
920
+ };
921
+ }
922
+ }
923
+
924
+ unsafe fn initiate_shutdown(signum: i32, action: sighandler_t) {
925
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
926
+ return;
927
+ }
928
+ SHUTDOWN_REQUESTED.store(true, Ordering::SeqCst);
929
+ info!(
930
+ "(PID={}) Initiating shutdown with signal {}: {}",
931
+ libc::getpid(),
932
+ signum,
933
+ action
934
+ );
935
+
936
+ if let Some(sender) = CTRL_C_SIGNAL_CHANNEL.write().unwrap().take() {
937
+ let _ = sender.send(());
938
+ }
939
+
940
+ let children = CHILDREN.read().unwrap();
941
+ for child in children.iter() {
942
+ if child.pid > 0 {
943
+ let sig_result = libc::kill(child.pid, libc::SIGTERM);
944
+ if sig_result == -1 {
945
+ eprint!("Failed to send SIGTERM to child {}", child.pid);
946
+ }
947
+ let mut status = 0;
948
+ let wait_result = libc::waitpid(child.pid, &mut status, libc::WNOHANG);
949
+ if wait_result != 0 {
950
+ libc::kill(child.pid, libc::SIGKILL);
951
+ libc::waitpid(child.pid, &mut status, 0);
952
+ }
953
+ }
954
+ }
955
+ }
956
+
957
/// Arguments passed, as a raw pointer, to the `before_fork` (and presumably `after_fork`)
/// C callbacks.
struct PreforkArgs {
    /// Number of worker processes — presumably how many to fork (fork code not shown).
    workers: u16,
    /// Optional Ruby hook invoked by `before_fork`.
    before_fork: Option<Proc>,
    /// Optional Ruby hook — presumably invoked by `after_fork` in each worker.
    after_fork: Option<Proc>,
}
962
+
963
+ pub unsafe extern "C" fn kill_other_threads(
964
+ _: *mut std::os::raw::c_void,
965
+ ) -> *mut std::os::raw::c_void {
966
+ let thread_class =
967
+ RClass::from_value(eval("Thread").expect("Failed to fetch Ruby Thread class"))
968
+ .expect("Failed to fetch Ruby Thread class");
969
+ let threads: RArray = thread_class
970
+ .funcall("list", ())
971
+ .expect("Failed to list Ruby threads");
972
+ let current: Value = thread_class
973
+ .funcall("current", ())
974
+ .expect("Failed to get current thread");
975
+
976
+ // Signal thread exit
977
+ for thr in threads {
978
+ let same: bool = thr
979
+ .funcall("==", (current,))
980
+ .expect("Failed to compare threads");
981
+ if !same {
982
+ let _: Option<Value> = thr.funcall("exit", ()).expect("Failed to exit thread");
983
+ }
984
+ }
985
+
986
+ // Attempt to join
987
+ for thr in threads {
988
+ let same: bool = thr
989
+ .funcall("==", (current,))
990
+ .expect("Failed to compare threads");
991
+ if !same {
992
+ let _: Option<Value> = thr
993
+ .funcall("join", (0.5_f64,))
994
+ .expect("Failed to join thread");
995
+ }
996
+ }
997
+
998
+ // Terminate
999
+ for thr in threads {
1000
+ let same: bool = thr
1001
+ .funcall("==", (current,))
1002
+ .expect("Failed to compare threads");
1003
+ if !same {
1004
+ let alive: bool = thr
1005
+ .funcall("alive?", ())
1006
+ .expect("Failed to check if thread is alive");
1007
+ if alive {
1008
+ let _: Option<Value> = thr.funcall("kill", ()).expect("Failed to kill thread");
1009
+ }
1010
+ }
1011
+ }
1012
+
1013
+ std::ptr::null_mut()
1014
+ }
1015
+
1016
+ pub unsafe extern "C" fn before_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1017
+ let pf_args = &mut *(ptr as *mut PreforkArgs);
1018
+ debug!("(PID={}) Before fork", libc::getpid());
1019
+ if let Some(ref p) = pf_args.before_fork {
1020
+ let _ = p.call::<(), Value>(());
1021
+ };
1022
+ std::ptr::null_mut()
1023
+ }
1024
+
1025
+ pub unsafe extern "C" fn after_fork(ptr: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1026
+ let pf_args = &mut *(ptr as *mut PreforkArgs);
1027
+ debug!("(PID={}) After fork", libc::getpid());
1028
+ if let Some(ref p) = pf_args.after_fork {
1029
+ let _ = p.call::<(), Value>(());
1030
+ };
1031
+ std::ptr::null_mut()
1032
+ }
1033
+
1034
+ pub unsafe extern "C" fn rb_fork(_: *mut std::os::raw::c_void) -> *mut std::os::raw::c_void {
1035
+ let ruby = Ruby::get_unchecked();
1036
+ let pid: Option<i32> = ruby
1037
+ .module_kernel()
1038
+ .funcall(ruby.intern("fork"), ())
1039
+ .unwrap();
1040
+ if let Some(child_pid) = pid {
1041
+ debug!(
1042
+ "(PID={}) Forked child process with PID {}",
1043
+ libc::getpid(),
1044
+ child_pid
1045
+ );
1046
+ return child_pid as *mut std::os::raw::c_void;
1047
+ } else {
1048
+ debug!("(PID={}) New child process starting ", libc::getpid());
1049
+ std::ptr::null_mut()
1050
+ }
1051
+ }
1052
+
1053
+ unsafe extern "C" fn prefork_setup(ptr: *mut c_void) -> *mut c_void {
1054
+ let pf_args = &*(ptr as *mut PreforkArgs);
1055
+ let workers = pf_args.workers;
1056
+ if workers >= 1 {
1057
+ let mut children = Vec::with_capacity(workers as usize - 1);
1058
+ rb_thread_call_with_gvl(Some(before_fork), ptr);
1059
+ rb_thread_call_with_gvl(Some(kill_other_threads), std::ptr::null_mut());
1060
+
1061
+ for _i in 0..workers {
1062
+ let (read, write) = socketpair(
1063
+ AddressFamily::Unix,
1064
+ SockType::Stream,
1065
+ None,
1066
+ SockFlag::empty(),
1067
+ )
1068
+ .unwrap();
1069
+
1070
+ let response_ptr = rb_thread_call_with_gvl(Some(rb_fork), std::ptr::null_mut());
1071
+ let pid = response_ptr as i32;
1072
+ debug!("Got reader and writer");
1073
+
1074
+ match pid {
1075
+ 0 => {
1076
+ let _ = Box::into_raw(Box::new(write));
1077
+
1078
+ rb_thread_call_with_gvl(Some(after_fork), ptr);
1079
+ let devnull = OpenOptions::new()
1080
+ .write(true)
1081
+ .open("/dev/null")
1082
+ .expect("Failed to open /dev/null");
1083
+ unsafe {
1084
+ libc::dup2(devnull.as_raw_fd(), libc::STDIN_FILENO);
1085
+ }
1086
+ return Box::into_raw(Box::new(read)) as *mut c_void;
1087
+ }
1088
+ child_pid => {
1089
+ let _ = Box::into_raw(Box::new(read));
1090
+ children.push(ChildProcess {
1091
+ pid: child_pid,
1092
+ writer: Arc::new(write),
1093
+ })
1094
+ }
1095
+ }
1096
+ }
1097
+ CHILDREN.write().unwrap().extend(children);
1098
+ }
1099
+ return std::ptr::null_mut();
1100
+ }
1101
+
1102
+ // TLS configuration
1103
+ fn configure_tls(cert_path: Option<String>, key_path: Option<String>) -> Option<Arc<ServerConfig>> {
1104
+ if cert_path.is_none() || key_path.is_none() {
1105
+ return None;
1106
+ }
1107
+ use tokio_rustls::rustls::{Certificate, PrivateKey};
1108
+
1109
+ let cert_bytes =
1110
+ std::fs::read(cert_path.expect("Expected a cert_path")).expect("Failed to read cert file");
1111
+ let key_bytes =
1112
+ std::fs::read(key_path.expect("Expected a key_path")).expect("Failed to read key file");
1113
+
1114
+ let certs = vec![Certificate(cert_bytes)];
1115
+ let key = PrivateKey(key_bytes);
1116
+
1117
+ let mut config = ServerConfig::builder()
1118
+ .with_safe_defaults()
1119
+ .with_no_client_auth()
1120
+ .with_single_cert(certs, key)
1121
+ .expect("Failed to build TLS config");
1122
+ config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1123
+
1124
+ Some(Arc::new(config))
1125
+ }
1126
+
1127
+ fn filter_keywords(keywords: RHash, allowed: &[&str]) -> RHash {
1128
+ let filtered = RHash::new();
1129
+ for key in keywords.funcall::<_, _, RArray>("keys", ()).unwrap() {
1130
+ let key_str: String = key.funcall("to_s", ()).unwrap();
1131
+ if allowed.contains(&key_str.as_str()) {
1132
+ filtered.aset(key, keywords.get(key).unwrap()).unwrap();
1133
+ }
1134
+ }
1135
+ filtered
1136
+ }
1137
+ /// The main "start(...)" function that Ruby calls via `Osprey.start`
1138
+ fn start_server(args: &[Value]) -> Result<(), Error> {
1139
+ SHUTDOWN_REQUESTED.store(false, Ordering::SeqCst);
1140
+
1141
+ let srv = RackServerConfig::from_args(args)?;
1142
+
1143
+ let mut http_listener: Option<TcpListener> = None;
1144
+ let mut https_listener: Option<TcpListener> = None;
1145
+ let mut reader: Option<OwnedFd> = None;
1146
+
1147
+ let mut pf_args = PreforkArgs {
1148
+ workers: srv.workers,
1149
+ before_fork: srv.before_fork,
1150
+ after_fork: srv.after_fork,
1151
+ };
1152
+
1153
+ let pf_args_ptr = &mut pf_args as *mut PreforkArgs as *mut c_void;
1154
+
1155
+ unsafe {
1156
+ info!("(PID={}) Parent process started.", libc::getpid());
1157
+ let ptr = rb_thread_call_without_gvl(
1158
+ Some(prefork_setup),
1159
+ pf_args_ptr,
1160
+ None,
1161
+ std::ptr::null_mut(),
1162
+ );
1163
+ if !ptr.is_null() {
1164
+ reader = Some(*Box::from_raw(ptr as *mut OwnedFd));
1165
+ }
1166
+ }
1167
+
1168
+ unsafe {
1169
+ let is_parent = !CHILDREN.read().unwrap().is_empty(); //|| srv.workers == 1;
1170
+
1171
+ if is_parent {
1172
+ (http_listener, https_listener) = setup_listeners(&srv)?;
1173
+ }
1174
+ let serve_args = AcceptLoopArgs {
1175
+ http_listener,
1176
+ https_listener,
1177
+ reader,
1178
+ cert_path: srv.cert_path,
1179
+ key_path: srv.key_path,
1180
+ threads: srv.threads,
1181
+ shutdown_timeout: srv.shutdown_timeout,
1182
+ script_name: srv.script_name,
1183
+ };
1184
+ let args_ptr = Box::into_raw(Box::new(serve_args)) as *mut c_void;
1185
+ if is_parent {
1186
+ call_without_gvl(|| 3 + 2);
1187
+ rb_thread_call_without_gvl(
1188
+ Some(accept_loop),
1189
+ args_ptr,
1190
+ Some(accept_loop_interrupt),
1191
+ std::ptr::null_mut(),
1192
+ );
1193
+ wait_for_children(srv.shutdown_timeout);
1194
+ CHILDREN.write().unwrap().clear();
1195
+ } else {
1196
+ rb_thread_call_without_gvl(
1197
+ Some(accept_ruby_connection_loop),
1198
+ args_ptr,
1199
+ None,
1200
+ std::ptr::null_mut(),
1201
+ );
1202
+ }
1203
+ }
1204
+
1205
+ debug!("(PID={}) accept_loop returned. Exiting.", unsafe {
1206
+ libc::getpid()
1207
+ });
1208
+ Ok(())
1209
+ }
1210
+
1211
+ unsafe fn is_pid_running(pid: i32) -> bool {
1212
+ let mut status = 0;
1213
+ let result = unsafe { libc::waitpid(pid, &mut status, libc::WNOHANG) };
1214
+ return result == 0;
1215
+ }
1216
+
1217
+ fn wait_for_children(shutdown_timeout: Duration) {
1218
+ unsafe {
1219
+ libc::signal(libc::SIGTERM, initiate_shutdown as usize);
1220
+ libc::signal(libc::SIGINT, initiate_shutdown as usize);
1221
+ }
1222
+
1223
+ for child_process in CHILDREN.read().unwrap().iter() {
1224
+ let mut status = 0;
1225
+
1226
+ loop {
1227
+ if !unsafe { is_pid_running(child_process.pid) } {
1228
+ break;
1229
+ }
1230
+
1231
+ let wait_result =
1232
+ unsafe { libc::waitpid(child_process.pid, &mut status, libc::WNOHANG) };
1233
+
1234
+ if wait_result == -1 {
1235
+ // Error occurred, handle appropriately
1236
+
1237
+ debug!(
1238
+ "(PID={}) Error waiting for child process: {}",
1239
+ child_process.pid,
1240
+ std::io::Error::last_os_error()
1241
+ );
1242
+ break;
1243
+ } else if wait_result > 0 {
1244
+ if libc::WIFSIGNALED(status) || libc::WIFSTOPPED(status) || libc::WIFEXITED(status)
1245
+ {
1246
+ break;
1247
+ }
1248
+ } else {
1249
+ unsafe {
1250
+ if SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
1251
+ debug!(
1252
+ "(PID={}) Waiting for child {} to exit. Status: {}. Wait Result: {}",
1253
+ libc::getpid(),
1254
+ child_process.pid,
1255
+ status,
1256
+ wait_result
1257
+ );
1258
+ sleep(shutdown_timeout.as_secs() as u32 + 1);
1259
+ }
1260
+ };
1261
+ }
1262
+ }
1263
+ }
1264
+ }
1265
+
1266
+ impl RackServerConfig {
1267
+ fn from_args(args: &[Value]) -> Result<Self, Error> {
1268
+ let scan_args: Args<(), (), (), (), RHash, ()> = scan_args(args)?;
1269
+
1270
+ let args: KwArgs<
1271
+ (Value,),
1272
+ (
1273
+ Option<String>,
1274
+ Option<u16>,
1275
+ Option<u16>,
1276
+ Option<u16>,
1277
+ Option<u16>,
1278
+ Option<String>,
1279
+ Option<String>,
1280
+ Option<f64>,
1281
+ Option<String>,
1282
+ ),
1283
+ (),
1284
+ > = get_kwargs(
1285
+ filter_keywords(
1286
+ scan_args.keywords,
1287
+ &[
1288
+ "app",
1289
+ "host",
1290
+ "port",
1291
+ "http_port",
1292
+ "workers",
1293
+ "threads",
1294
+ "cert_path",
1295
+ "key_path",
1296
+ "shutdown_timeout",
1297
+ "script_name",
1298
+ ],
1299
+ ),
1300
+ &["app"],
1301
+ &[
1302
+ "host",
1303
+ "port",
1304
+ "http_port",
1305
+ "workers",
1306
+ "threads",
1307
+ "cert_path",
1308
+ "key_path",
1309
+ "shutdown_timeout",
1310
+ "script_name",
1311
+ ],
1312
+ )?;
1313
+
1314
+ let args2: KwArgs<(), (Option<Proc>, Option<Proc>), ()> = get_kwargs(
1315
+ filter_keywords(scan_args.keywords, &["before_fork", "after_fork"]),
1316
+ &[],
1317
+ &["before_fork", "after_fork"],
1318
+ )?;
1319
+
1320
+ let ruby = Ruby::get().unwrap();
1321
+ let host = args.optional.0.unwrap_or("127.0.0.1".to_owned());
1322
+ let port = args.optional.1.unwrap_or(3000);
1323
+ let http_port = args.optional.2;
1324
+ let workers = args.optional.3.unwrap_or(1);
1325
+ let threads = args.optional.4.unwrap_or(1);
1326
+ let cert_path = args.optional.5;
1327
+ let key_path = args.optional.6;
1328
+ let shutdown_timeout =
1329
+ Duration::from_millis((args.optional.7.unwrap_or(5.0) * 1000.0) as u64);
1330
+ let script_name = args.optional.8.unwrap_or("/".to_owned());
1331
+
1332
+ let app = args.required.0;
1333
+ let handler: Value = app.funcall("method", (*ID_CALL,)).unwrap();
1334
+ ruby.get_inner(&OSPREY)
1335
+ .ivar_set("@handler", handler)
1336
+ .unwrap();
1337
+ ruby.get_inner(&WORK_QUEUE)
1338
+ .funcall::<&str, (), Value>("clear", ())
1339
+ .ok();
1340
+ Ok(Self {
1341
+ host,
1342
+ port,
1343
+ http_port,
1344
+ workers,
1345
+ threads,
1346
+ cert_path,
1347
+ key_path,
1348
+ shutdown_timeout,
1349
+ script_name,
1350
+ before_fork: args2.optional.0,
1351
+ after_fork: args2.optional.1,
1352
+ })
1353
+ }
1354
+ }
1355
+ /// Standard Magnus init
1356
+ #[magnus::init]
1357
+ fn init(ruby: &magnus::Ruby) -> Result<(), Error> {
1358
+ SimpleLogger::new().env().init().unwrap();
1359
+
1360
+ ruby.get_inner(&OSPREY)
1361
+ .define_singleton_method("start", function!(start_server, -1))?;
1362
+
1363
+ define_collector(ruby)?;
1364
+ define_request_info(&ruby)?;
1365
+
1366
+ Ok(())
1367
+ }
1368
+
1369
+ // Mapping functions for requests and responses
1370
+ #[inline]
1371
+ async fn map_hyper_request(
1372
+ parts: hyper::http::request::Parts,
1373
+ peer_addr: SocketAddr,
1374
+ script_name: String,
1375
+ stream_fd: i32,
1376
+ ) -> (RackRequest, oneshot::Receiver<RackResponse>) {
1377
+ let (sender, receiver) = oneshot::channel();
1378
+
1379
+ let headers: HashMap<String, String> = parts
1380
+ .headers
1381
+ .iter()
1382
+ .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
1383
+ .collect();
1384
+
1385
+ let protocol = headers
1386
+ .get("upgrade")
1387
+ .and_then(|value| {
1388
+ Some(
1389
+ value
1390
+ .split(',')
1391
+ .map(|s| s.trim().to_string())
1392
+ .collect::<Vec<_>>(),
1393
+ )
1394
+ })
1395
+ .or_else(|| {
1396
+ headers.get("protocol").map(|value| {
1397
+ value
1398
+ .split(',')
1399
+ .map(|s| s.trim().to_string())
1400
+ .collect::<Vec<_>>()
1401
+ })
1402
+ })
1403
+ .unwrap_or_else(|| vec!["http".to_string()]);
1404
+
1405
+ unsafe {
1406
+ debug!("(PID={}) peer_adr: {:?}", libc::getpid(), peer_addr);
1407
+ debug!(
1408
+ "(PID={}) host header: {:?}",
1409
+ libc::getpid(),
1410
+ headers.get("host").unwrap_or(&"".to_string())
1411
+ );
1412
+ };
1413
+ let scheme = parts.uri.scheme_str().unwrap_or("http").to_string();
1414
+ let host = parts
1415
+ .uri
1416
+ .host()
1417
+ .unwrap_or(headers.get("host").unwrap_or(&"".to_string()))
1418
+ .to_string();
1419
+ let port = host
1420
+ .split(':')
1421
+ .nth(1)
1422
+ .unwrap_or(if &scheme == "https" { "443" } else { "80" })
1423
+ .parse::<i32>()
1424
+ .unwrap();
1425
+
1426
+ let query_string = parts.uri.query().unwrap_or("").to_string();
1427
+ (
1428
+ RackRequest(RefCell::new(Some(RackRequestInner {
1429
+ path: parts.uri.path().to_string(),
1430
+ method: parts.method.to_string(),
1431
+ version: format!("{:?}", parts.version),
1432
+ remote_addr: peer_addr.ip().to_string(),
1433
+ script_name,
1434
+ scheme,
1435
+ host,
1436
+ port,
1437
+ protocol,
1438
+ query_string,
1439
+ headers,
1440
+ stream_fd,
1441
+ sender: Arc::new(Mutex::new(Some(sender))),
1442
+ }))),
1443
+ receiver,
1444
+ )
1445
+ }
1446
+
1447
+ use hyper::body::{Frame, SizeHint};
1448
+ use std::task::{Context, Poll};
1449
+ use tokio::{
1450
+ io::{AsyncBufReadExt, AsyncReadExt, BufReader},
1451
+ sync::mpsc,
1452
+ };
1453
+
1454
+ #[derive(Debug)]
1455
+ pub struct HijackedStreamingBody {
1456
+ rx: mpsc::Receiver<Result<Bytes, io::Error>>,
1457
+ done: bool,
1458
+ }
1459
+
1460
+ impl hyper::body::Body for HijackedStreamingBody {
1461
+ type Data = Bytes;
1462
+ type Error = Infallible;
1463
+
1464
+ fn poll_frame(
1465
+ mut self: Pin<&mut Self>,
1466
+ cx: &mut Context<'_>,
1467
+ ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1468
+ if self.done {
1469
+ return Poll::Ready(None);
1470
+ }
1471
+
1472
+ match Pin::new(&mut self.rx).poll_recv(cx) {
1473
+ Poll::Ready(Some(Ok(bytes))) => {
1474
+ let frame = Frame::data(bytes);
1475
+ Poll::Ready(Some(Ok(frame)))
1476
+ }
1477
+ Poll::Ready(Some(Err(e))) => {
1478
+ self.done = true;
1479
+ Poll::Ready(None)
1480
+ }
1481
+ Poll::Ready(None) => Poll::Ready(None),
1482
+ Poll::Pending => Poll::Pending,
1483
+ }
1484
+ }
1485
+ fn is_end_stream(&self) -> bool {
1486
+ false // rely on the channel closing
1487
+ }
1488
+
1489
+ fn size_hint(&self) -> SizeHint {
1490
+ SizeHint::default()
1491
+ }
1492
+ }
1493
+
1494
+ // ─── A HELPER TO PARSE THE HTTP/1.1 RESPONSE ───────────────────────────────
1495
+
1496
+ // Given a buffered reader reading from the hijacked connection, read until the header section
1497
+ // is complete (i.e. until "\r\n\r\n" is encountered), and then return the header lines and any
1498
+ // bytes that already belong to the body.
1499
+ async fn read_response_headers<R>(reader: &mut R) -> io::Result<(Vec<String>, Vec<u8>)>
1500
+ where
1501
+ R: AsyncBufReadExt + Unpin,
1502
+ {
1503
+ let mut headers_buf = Vec::new();
1504
+ // Read until we have the header terminator.
1505
+ loop {
1506
+ let mut line = String::new();
1507
+ let n = reader.read_line(&mut line).await?;
1508
+ if n == 0 {
1509
+ // EOF reached unexpectedly.
1510
+ break;
1511
+ }
1512
+ headers_buf.extend_from_slice(line.as_bytes());
1513
+ if headers_buf.ends_with(b"\r\n\r\n") || headers_buf.ends_with(b"\n\n") {
1514
+ break;
1515
+ }
1516
+ }
1517
+ // Split the headers from any already-read part of the body.
1518
+ let header_body_split = b"\r\n\r\n";
1519
+ if let Some(pos) = headers_buf
1520
+ .windows(header_body_split.len())
1521
+ .position(|window| window == header_body_split)
1522
+ {
1523
+ let header_lines = &headers_buf[..pos];
1524
+ let remaining = headers_buf[(pos + header_body_split.len())..].to_vec();
1525
+ // Split the header lines into strings.
1526
+ let headers = String::from_utf8_lossy(header_lines)
1527
+ .lines()
1528
+ .map(|s| s.to_string())
1529
+ .collect::<Vec<_>>();
1530
+ Ok((headers, remaining))
1531
+ } else {
1532
+ // If not found, return all as headers and no body data.
1533
+ let headers = String::from_utf8_lossy(&headers_buf)
1534
+ .lines()
1535
+ .map(|s| s.to_string())
1536
+ .collect::<Vec<_>>();
1537
+ Ok((headers, Vec::new()))
1538
+ }
1539
+ }
1540
+
1541
+ /// Given the raw header lines (from an HTTP/1.1 response), parse out the status and headers.
1542
+ ///
1543
+ /// This is a very simplistic parser; you may wish to use a real HTTP parser.
1544
+ fn parse_status_and_headers(lines: &[String]) -> (StatusCode, HeaderMap) {
1545
+ let mut status = StatusCode::OK;
1546
+ let mut header_map = HeaderMap::new();
1547
+
1548
+ if let Some(first_line) = lines.first() {
1549
+ // Expect a line like "HTTP/1.1 200 OK"
1550
+ let parts: Vec<&str> = first_line.split_whitespace().collect();
1551
+ if parts.len() >= 2 {
1552
+ if let Ok(code) = parts[1].parse::<u16>() {
1553
+ status = StatusCode::from_u16(code).unwrap_or(StatusCode::OK);
1554
+ }
1555
+ }
1556
+ }
1557
+ for line in &lines[1..] {
1558
+ if let Some((name, value)) = line.split_once(':') {
1559
+ header_map.insert(
1560
+ HeaderName::from_bytes(name.as_bytes()).unwrap(),
1561
+ HeaderValue::from_bytes(value.as_bytes()).unwrap(),
1562
+ );
1563
+ }
1564
+ }
1565
+ (status, header_map)
1566
+ }
1567
+ // ─── THE HIJACK HANDLING FUNCTION ───────────────────────────────────────────
1568
+
1569
+ async fn read_hijacked_headers(fd: i32) -> (HeaderMap, StatusCode, bool, Vec<u8>) {
1570
+ let hijack_stream =
1571
+ tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(fd) }).unwrap();
1572
+ let mut reader = BufReader::new(hijack_stream);
1573
+
1574
+ // Read until blank line => that's the "HTTP/1.1 101\r\nUpg...\r\n\r\n"
1575
+ let (header_lines, initial_body_bytes) = read_response_headers(&mut reader).await.unwrap();
1576
+
1577
+ // Parse them
1578
+ let (status, hdrs) = parse_status_and_headers(&header_lines);
1579
+ let requires_upgrade = status == StatusCode::SWITCHING_PROTOCOLS;
1580
+
1581
+ // Dont close the FD, we're going to use it later.
1582
+ let _ = reader.into_inner().into_std().unwrap().into_raw_fd();
1583
+
1584
+ (hdrs, status, requires_upgrade, initial_body_bytes)
1585
+ }
1586
+
1587
+ /// This function is called when rack_response.status == -1 (i.e. the app hijacked the stream).
1588
+ /// It takes the raw UnixStream (from the proxy side) and uses it to parse an HTTP/1.1 response
1589
+ /// written by the app. It then returns a Hyper Response that uses a streaming body.
1590
+ /// This function is called when `rack_response.status == -1`.
1591
+ /// We read the lines Ruby wrote over that hijacked FD, parse them,
1592
+ /// and build a streaming response. But if Ruby’s lines say `101`,
1593
+ /// we do not close the connection until Ruby closes its socket.
1594
+ async fn handle_hijacked_response(
1595
+ hijack_stream_fd: i32,
1596
+ parts: hyper::http::request::Parts,
1597
+ ) -> Result<Response<BoxBody<Bytes, Infallible>>, hyper::Error> {
1598
+ // We read the headers from the raw stream (in HTTP/1.1 format)
1599
+ // and will either upgrade the connection or use the body as a streaming body.
1600
+ let (headers, status, requires_upgrade, initial_body_bytes) =
1601
+ read_hijacked_headers(hijack_stream_fd).await;
1602
+
1603
+ let hijack_stream =
1604
+ tokio::net::UnixStream::from_std(unsafe { UnixStream::from_raw_fd(hijack_stream_fd) })
1605
+ .unwrap();
1606
+
1607
+ let mut response = if requires_upgrade {
1608
+ tokio::spawn(async move {
1609
+ let mut req = Request::from_parts(parts, Empty::<Bytes>::new());
1610
+ match hyper::upgrade::on(&mut req).await {
1611
+ Ok(upgraded) => {
1612
+ two_way_bridge(upgraded, hijack_stream)
1613
+ .await
1614
+ .expect("Error in creating two way bridge");
1615
+ }
1616
+ Err(e) => eprintln!("upgrade error: {:?}", e),
1617
+ }
1618
+ });
1619
+ Response::new(BoxBody::new(Empty::new()))
1620
+ } else {
1621
+ let (tx, rx) = mpsc::channel::<Result<Bytes, io::Error>>(16);
1622
+ if !initial_body_bytes.is_empty() {
1623
+ let _ = tx.send(Ok(Bytes::from(initial_body_bytes))).await;
1624
+ }
1625
+
1626
+ tokio::spawn(async move {
1627
+ let mut reader = BufReader::new(hijack_stream);
1628
+ let mut buf = [0u8; 1024];
1629
+ loop {
1630
+ match reader.read(&mut buf).await {
1631
+ Ok(0) => {
1632
+ info!("Hijacked socket closed. EOF");
1633
+ break;
1634
+ }
1635
+ Ok(n) => {
1636
+ // Forward chunk
1637
+ if tx
1638
+ .send(Ok(Bytes::copy_from_slice(&buf[..n])))
1639
+ .await
1640
+ .is_err()
1641
+ {
1642
+ info!("Client side dropped the channel");
1643
+ break;
1644
+ }
1645
+ }
1646
+ Err(e) => {
1647
+ info!("Read error: {:?}", e);
1648
+ let _ = tx.send(Err(e)).await;
1649
+ break;
1650
+ }
1651
+ }
1652
+ }
1653
+ });
1654
+ Response::new(BoxBody::new(HijackedStreamingBody { rx, done: false }))
1655
+ };
1656
+ *response.status_mut() = status;
1657
+ *response.headers_mut() = headers;
1658
+ Ok(response)
1659
+ }