starpc 0.41.2 → 0.43.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +101 -20
- package/dist/mock/mock.pb.d.ts +1 -1
- package/dist/mock/mock.pb.js +4 -5
- package/dist/mock/mock_srpc.pb.d.ts +3 -3
- package/dist/mock/mock_srpc.pb.js +5 -5
- package/echo/Cargo.toml +21 -0
- package/echo/build.rs +15 -0
- package/echo/echo.pb.cc +405 -0
- package/echo/echo.pb.go +9 -27
- package/echo/echo.pb.h +364 -0
- package/echo/echo_srpc.pb.go +1 -1
- package/echo/gen/mod.rs +3 -0
- package/echo/main.rs +162 -0
- package/go.mod +18 -10
- package/go.sum +28 -18
- package/mock/mock.pb.cc +394 -0
- package/mock/mock.pb.go +9 -27
- package/mock/mock.pb.h +366 -0
- package/mock/mock.pb.ts +11 -13
- package/mock/mock_srpc.pb.go +1 -1
- package/mock/mock_srpc.pb.ts +12 -9
- package/package.json +27 -25
- package/srpc/Cargo.toml +26 -0
- package/srpc/build.rs +15 -0
- package/srpc/client.rs +356 -0
- package/srpc/codec.rs +225 -0
- package/srpc/error.rs +177 -0
- package/srpc/handler.rs +163 -0
- package/srpc/invoker.rs +192 -0
- package/srpc/lib.rs +107 -0
- package/srpc/message.rs +9 -0
- package/srpc/mux.rs +353 -0
- package/srpc/packet.rs +334 -0
- package/srpc/proto/mod.rs +10 -0
- package/srpc/rpc.rs +777 -0
- package/srpc/rpcproto.pb.cc +1381 -0
- package/srpc/rpcproto.pb.go +75 -183
- package/srpc/rpcproto.pb.h +1451 -0
- package/srpc/server.rs +337 -0
- package/srpc/stream.rs +304 -0
- package/srpc/testing.rs +290 -0
- package/srpc/tests/integration_test.rs +495 -0
- package/srpc/transport.rs +218 -0
- package/Makefile +0 -154
package/srpc/server.rs
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
//! Server implementation for starpc.
|
|
2
|
+
//!
|
|
3
|
+
//! This module provides the server-side API for handling incoming RPC calls.
|
|
4
|
+
//! The server supports all streaming patterns: unary, client streaming,
|
|
5
|
+
//! server streaming, and bidirectional streaming.
|
|
6
|
+
|
|
7
|
+
use bytes::Bytes;
|
|
8
|
+
use futures::StreamExt;
|
|
9
|
+
use std::sync::Arc;
|
|
10
|
+
use std::time::Duration;
|
|
11
|
+
use tokio::io::{AsyncRead, AsyncWrite};
|
|
12
|
+
use tokio_util::codec::FramedRead;
|
|
13
|
+
|
|
14
|
+
use crate::codec::PacketCodec;
|
|
15
|
+
use crate::error::{Error, Result};
|
|
16
|
+
use crate::invoker::Invoker;
|
|
17
|
+
use crate::packet::Validate;
|
|
18
|
+
use crate::proto::packet::Body;
|
|
19
|
+
use crate::rpc::{PacketWriter, ServerRpc};
|
|
20
|
+
use crate::stream::{Context, Stream};
|
|
21
|
+
use crate::transport::TransportPacketWriter;
|
|
22
|
+
|
|
23
|
+
/// Default grace period (100 ms) given to a connection's read task after the
/// handler has completed, before that task is aborted.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(100);
|
|
25
|
+
|
|
26
|
+
/// Server configuration options.
///
/// Build one from [`Default::default`] and override the fields you care about,
/// e.g. `ServerConfig { shutdown_timeout: Duration::from_secs(1) }`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerConfig {
    /// How long `Server::handle_stream` waits for the connection's read task
    /// to finish after the handler completes, before aborting that task.
    pub shutdown_timeout: Duration,
}
|
|
32
|
+
|
|
33
|
+
impl Default for ServerConfig {
|
|
34
|
+
fn default() -> Self {
|
|
35
|
+
Self {
|
|
36
|
+
shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/// Server for handling incoming RPC connections.
|
|
42
|
+
///
|
|
43
|
+
/// The Server routes incoming RPC calls to the appropriate handler
|
|
44
|
+
/// via the provided Invoker (typically a Mux).
|
|
45
|
+
///
|
|
46
|
+
/// # Example
|
|
47
|
+
///
|
|
48
|
+
/// ```ignore
|
|
49
|
+
/// use starpc::{Server, Mux};
|
|
50
|
+
///
|
|
51
|
+
/// let mux = Arc::new(Mux::new());
|
|
52
|
+
/// mux.register(Arc::new(MyServiceHandler))?;
|
|
53
|
+
///
|
|
54
|
+
/// let server = Server::new(mux);
|
|
55
|
+
/// server.handle_stream(tcp_stream).await?;
|
|
56
|
+
/// ```
|
|
57
|
+
pub struct Server<I: Invoker> {
|
|
58
|
+
/// The invoker for routing RPC calls.
|
|
59
|
+
pub invoker: Arc<I>,
|
|
60
|
+
|
|
61
|
+
/// Server configuration.
|
|
62
|
+
config: ServerConfig,
|
|
63
|
+
|
|
64
|
+
/// Optional error handler for connection errors.
|
|
65
|
+
/// If not set, errors are silently ignored.
|
|
66
|
+
error_handler: Option<Arc<dyn Fn(Error) + Send + Sync>>,
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
impl<I: Invoker + 'static> Server<I> {
|
|
70
|
+
/// Creates a new server with the given invoker.
|
|
71
|
+
pub fn new(invoker: I) -> Self {
|
|
72
|
+
Self {
|
|
73
|
+
invoker: Arc::new(invoker),
|
|
74
|
+
config: ServerConfig::default(),
|
|
75
|
+
error_handler: None,
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/// Creates a new server with a shared invoker.
|
|
80
|
+
pub fn with_arc(invoker: Arc<I>) -> Self {
|
|
81
|
+
Self {
|
|
82
|
+
invoker,
|
|
83
|
+
config: ServerConfig::default(),
|
|
84
|
+
error_handler: None,
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
/// Sets the server configuration.
|
|
89
|
+
pub fn with_config(mut self, config: ServerConfig) -> Self {
|
|
90
|
+
self.config = config;
|
|
91
|
+
self
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/// Sets an error handler for connection errors.
|
|
95
|
+
///
|
|
96
|
+
/// The handler is called when an error occurs during stream handling.
|
|
97
|
+
/// This is useful for logging or metrics.
|
|
98
|
+
pub fn with_error_handler<F>(mut self, handler: F) -> Self
|
|
99
|
+
where
|
|
100
|
+
F: Fn(Error) + Send + Sync + 'static,
|
|
101
|
+
{
|
|
102
|
+
self.error_handler = Some(Arc::new(handler));
|
|
103
|
+
self
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
/// Reports an error through the error handler, if configured.
|
|
107
|
+
fn report_error(&self, err: Error) {
|
|
108
|
+
if let Some(ref handler) = self.error_handler {
|
|
109
|
+
handler(err);
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/// Handles a single stream connection.
|
|
114
|
+
///
|
|
115
|
+
/// This reads packets from the stream, routes the RPC call to the
|
|
116
|
+
/// appropriate handler, and writes responses back.
|
|
117
|
+
///
|
|
118
|
+
/// The method returns when the RPC completes or an error occurs.
|
|
119
|
+
pub async fn handle_stream<T>(&self, transport: T) -> Result<()>
|
|
120
|
+
where
|
|
121
|
+
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
122
|
+
{
|
|
123
|
+
let (read_half, write_half) = tokio::io::split(transport);
|
|
124
|
+
|
|
125
|
+
// Create the packet writer.
|
|
126
|
+
let writer: Arc<dyn PacketWriter> = Arc::new(TransportPacketWriter::new(write_half));
|
|
127
|
+
|
|
128
|
+
// Create framed reader.
|
|
129
|
+
let mut framed = FramedRead::new(read_half, PacketCodec::new());
|
|
130
|
+
|
|
131
|
+
// Wait for the first packet (CallStart).
|
|
132
|
+
let first_packet = framed.next().await;
|
|
133
|
+
let call_start = match first_packet {
|
|
134
|
+
Some(Ok(packet)) => {
|
|
135
|
+
// Validate the packet
|
|
136
|
+
packet.validate()?;
|
|
137
|
+
|
|
138
|
+
match packet.body {
|
|
139
|
+
Some(Body::CallStart(cs)) => cs,
|
|
140
|
+
_ => return Err(Error::ExpectedCallStart),
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
Some(Err(e)) => return Err(e),
|
|
144
|
+
None => return Err(Error::StreamClosed),
|
|
145
|
+
};
|
|
146
|
+
|
|
147
|
+
// Validate service and method IDs (already done by packet.validate(),
|
|
148
|
+
// but double-check for clarity).
|
|
149
|
+
if call_start.rpc_service.is_empty() {
|
|
150
|
+
return Err(Error::EmptyServiceId);
|
|
151
|
+
}
|
|
152
|
+
if call_start.rpc_method.is_empty() {
|
|
153
|
+
return Err(Error::EmptyMethodId);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
let service_id = call_start.rpc_service.clone();
|
|
157
|
+
let method_id = call_start.rpc_method.clone();
|
|
158
|
+
|
|
159
|
+
// Create the server RPC.
|
|
160
|
+
let ctx = Context::new();
|
|
161
|
+
let rpc = Arc::new(ServerRpc::from_call_start(ctx, call_start, writer));
|
|
162
|
+
|
|
163
|
+
// Spawn a task to read remaining packets.
|
|
164
|
+
let rpc_clone = rpc.clone();
|
|
165
|
+
let mut read_task = tokio::spawn(async move {
|
|
166
|
+
while let Some(result) = framed.next().await {
|
|
167
|
+
match result {
|
|
168
|
+
Ok(packet) => {
|
|
169
|
+
if rpc_clone.handle_packet(packet).await.is_err() {
|
|
170
|
+
break;
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
Err(_) => break,
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
let _ = rpc_clone.handle_stream_close(None).await;
|
|
177
|
+
});
|
|
178
|
+
|
|
179
|
+
// Invoke the method.
|
|
180
|
+
let stream: Box<dyn Stream> = Box::new(ServerStream { rpc: rpc.clone() });
|
|
181
|
+
let (found, result) = self
|
|
182
|
+
.invoker
|
|
183
|
+
.invoke_method(&service_id, &method_id, stream)
|
|
184
|
+
.await;
|
|
185
|
+
|
|
186
|
+
// Handle the result.
|
|
187
|
+
if !found {
|
|
188
|
+
// Send unimplemented error.
|
|
189
|
+
let _ = rpc.send_error("method not implemented".to_string()).await;
|
|
190
|
+
} else if let Err(e) = result {
|
|
191
|
+
// Send error response.
|
|
192
|
+
let _ = rpc.send_error(e.to_string()).await;
|
|
193
|
+
} else {
|
|
194
|
+
// Close send side on success.
|
|
195
|
+
let _ = rpc.close_send().await;
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
// Explicitly close the RPC to release the writer/transport.
|
|
199
|
+
// This ensures the connection doesn't remain open after the terminal response.
|
|
200
|
+
let _ = rpc.close().await;
|
|
201
|
+
|
|
202
|
+
// Wait for the read task to finish or timeout, then abort it.
|
|
203
|
+
// This gives time for the client to receive our response.
|
|
204
|
+
tokio::select! {
|
|
205
|
+
_ = &mut read_task => {
|
|
206
|
+
// Read task finished naturally
|
|
207
|
+
}
|
|
208
|
+
_ = tokio::time::sleep(self.config.shutdown_timeout) => {
|
|
209
|
+
// Timeout - abort the read task
|
|
210
|
+
read_task.abort();
|
|
211
|
+
}
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
Ok(())
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
/// Accepts and handles connections in a loop.
|
|
218
|
+
///
|
|
219
|
+
/// This is a convenience method that accepts connections from a listener
|
|
220
|
+
/// and spawns a task to handle each one.
|
|
221
|
+
///
|
|
222
|
+
/// Errors from individual connections are reported via the error handler
|
|
223
|
+
/// (if configured) but don't stop the server.
|
|
224
|
+
pub async fn serve<L, T>(&self, mut listener: L) -> Result<()>
|
|
225
|
+
where
|
|
226
|
+
L: futures::Stream<Item = std::io::Result<T>> + Unpin,
|
|
227
|
+
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
228
|
+
{
|
|
229
|
+
while let Some(result) = listener.next().await {
|
|
230
|
+
match result {
|
|
231
|
+
Ok(stream) => {
|
|
232
|
+
let server = self.clone_for_spawn();
|
|
233
|
+
tokio::spawn(async move {
|
|
234
|
+
if let Err(e) = server.handle_stream(stream).await {
|
|
235
|
+
server.report_error(e);
|
|
236
|
+
}
|
|
237
|
+
});
|
|
238
|
+
}
|
|
239
|
+
Err(e) => {
|
|
240
|
+
self.report_error(Error::Io(e));
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
Ok(())
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
/// Creates a clone of the server for spawning tasks.
|
|
249
|
+
fn clone_for_spawn(&self) -> Server<I> {
|
|
250
|
+
Server {
|
|
251
|
+
invoker: self.invoker.clone(),
|
|
252
|
+
config: self.config.clone(),
|
|
253
|
+
error_handler: self.error_handler.clone(),
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
/// Wrapper to provide Stream interface for ServerRpc.
|
|
259
|
+
struct ServerStream {
|
|
260
|
+
rpc: Arc<ServerRpc>,
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
#[async_trait::async_trait]
|
|
264
|
+
impl Stream for ServerStream {
|
|
265
|
+
fn context(&self) -> &Context {
|
|
266
|
+
self.rpc.context()
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()> {
|
|
270
|
+
self.rpc.send_bytes(data).await
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
async fn recv_bytes(&self) -> Result<Bytes> {
|
|
274
|
+
self.rpc.recv_bytes().await
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
async fn close_send(&self) -> Result<()> {
|
|
278
|
+
self.rpc.close_send().await
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
async fn close(&self) -> Result<()> {
|
|
282
|
+
self.rpc.close().await
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mux::Mux;
    use tokio::io::duplex;

    #[tokio::test]
    async fn test_server_config() {
        let custom = ServerConfig {
            shutdown_timeout: Duration::from_secs(1),
        };
        let server = Server::new(Mux::new()).with_config(custom);

        assert_eq!(server.config.shutdown_timeout, Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_server_with_error_handler() {
        use std::sync::Mutex;

        let captured: Arc<Mutex<Vec<String>>> = Arc::default();
        let sink = Arc::clone(&captured);

        let server = Server::new(Mux::new()).with_error_handler(move |err| {
            sink.lock().unwrap().push(err.to_string());
        });

        // Route one error through the handler.
        server.report_error(Error::StreamClosed);

        let messages = captured.lock().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], "stream closed");
    }

    #[tokio::test]
    async fn test_server_missing_call_start() {
        let server = Server::with_arc(Arc::new(Mux::new()));

        let (client_end, server_end) = duplex(1024);

        // Hang up before any CallStart is sent.
        drop(client_end);

        let outcome = server.handle_stream(server_end).await;
        assert!(matches!(outcome, Err(Error::StreamClosed)));
    }
}
|
package/srpc/stream.rs
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
//! Stream abstraction for RPC communication.
|
|
2
|
+
//!
|
|
3
|
+
//! This module provides the core `Stream` trait for bidirectional RPC
|
|
4
|
+
//! communication, along with context management for cancellation.
|
|
5
|
+
|
|
6
|
+
use async_trait::async_trait;
|
|
7
|
+
use bytes::Bytes;
|
|
8
|
+
use prost::Message;
|
|
9
|
+
use std::sync::Arc;
|
|
10
|
+
use tokio_util::sync::CancellationToken;
|
|
11
|
+
|
|
12
|
+
use crate::error::{Error, Result};
|
|
13
|
+
|
|
14
|
+
/// Context for an RPC stream, providing cancellation support.
|
|
15
|
+
///
|
|
16
|
+
/// The Context wraps a `CancellationToken` and provides a convenient API
|
|
17
|
+
/// for managing RPC lifecycles. Child contexts can be created that will
|
|
18
|
+
/// be cancelled when the parent is cancelled.
|
|
19
|
+
///
|
|
20
|
+
/// # Example
|
|
21
|
+
///
|
|
22
|
+
/// ```rust,ignore
|
|
23
|
+
/// let ctx = Context::new();
|
|
24
|
+
///
|
|
25
|
+
/// // Check if cancelled
|
|
26
|
+
/// if ctx.is_cancelled() {
|
|
27
|
+
/// return Err(Error::Cancelled);
|
|
28
|
+
/// }
|
|
29
|
+
///
|
|
30
|
+
/// // Wait for cancellation
|
|
31
|
+
/// ctx.cancelled().await;
|
|
32
|
+
///
|
|
33
|
+
/// // Create a child context
|
|
34
|
+
/// let child = ctx.child();
|
|
35
|
+
///
|
|
36
|
+
/// // Cancel the context
|
|
37
|
+
/// ctx.cancel();
|
|
38
|
+
/// ```
|
|
39
|
+
#[derive(Debug, Clone)]
|
|
40
|
+
pub struct Context {
|
|
41
|
+
/// Cancellation token for this context.
|
|
42
|
+
cancel_token: CancellationToken,
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
impl Default for Context {
|
|
46
|
+
fn default() -> Self {
|
|
47
|
+
Self::new()
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
impl Context {
|
|
52
|
+
/// Creates a new context.
|
|
53
|
+
pub fn new() -> Self {
|
|
54
|
+
Self {
|
|
55
|
+
cancel_token: CancellationToken::new(),
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
/// Creates a new context with a cancellation token.
|
|
60
|
+
pub fn with_cancel_token(cancel_token: CancellationToken) -> Self {
|
|
61
|
+
Self { cancel_token }
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
/// Creates a child context that will be cancelled when the parent is cancelled.
|
|
65
|
+
pub fn child(&self) -> Self {
|
|
66
|
+
Self {
|
|
67
|
+
cancel_token: self.cancel_token.child_token(),
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/// Returns the cancellation token.
|
|
72
|
+
pub fn cancel_token(&self) -> &CancellationToken {
|
|
73
|
+
&self.cancel_token
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
/// Cancels this context and all child contexts.
|
|
77
|
+
pub fn cancel(&self) {
|
|
78
|
+
self.cancel_token.cancel();
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
/// Returns true if this context is cancelled.
|
|
82
|
+
pub fn is_cancelled(&self) -> bool {
|
|
83
|
+
self.cancel_token.is_cancelled()
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
/// Waits until the context is cancelled.
|
|
87
|
+
///
|
|
88
|
+
/// This is useful for implementing cooperative cancellation in handlers.
|
|
89
|
+
pub async fn cancelled(&self) {
|
|
90
|
+
self.cancel_token.cancelled().await
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/// Returns a future that completes when the context is cancelled.
|
|
94
|
+
///
|
|
95
|
+
/// Unlike `cancelled()`, this returns an owned future that can be stored.
|
|
96
|
+
pub fn cancellation(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
|
|
97
|
+
let token = self.cancel_token.clone();
|
|
98
|
+
async move {
|
|
99
|
+
token.cancelled().await;
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/// Stream trait for bidirectional RPC communication.
|
|
105
|
+
///
|
|
106
|
+
/// This trait provides message send/receive operations for streaming RPCs.
|
|
107
|
+
/// The trait is object-safe by using bytes for the core methods.
|
|
108
|
+
///
|
|
109
|
+
/// # Implementation Notes
|
|
110
|
+
///
|
|
111
|
+
/// - All methods are async and may block waiting for I/O or messages
|
|
112
|
+
/// - `send_bytes` may return `Error::Completed` if `close_send` was already called
|
|
113
|
+
/// - `recv_bytes` returns `Error::StreamClosed` when the remote closes
|
|
114
|
+
/// - `close` cancels the context and releases all resources
|
|
115
|
+
#[async_trait]
|
|
116
|
+
pub trait Stream: Send + Sync {
|
|
117
|
+
/// Returns the context for this stream.
|
|
118
|
+
///
|
|
119
|
+
/// The context can be used to check for cancellation or to get
|
|
120
|
+
/// a cancellation token for cooperative cancellation.
|
|
121
|
+
fn context(&self) -> &Context;
|
|
122
|
+
|
|
123
|
+
/// Sends raw bytes on the stream.
|
|
124
|
+
///
|
|
125
|
+
/// # Errors
|
|
126
|
+
///
|
|
127
|
+
/// - `Error::Completed` if `close_send` was already called
|
|
128
|
+
/// - `Error::StreamClosed` if the connection was closed
|
|
129
|
+
/// - `Error::Io` for transport errors
|
|
130
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()>;
|
|
131
|
+
|
|
132
|
+
/// Receives raw bytes from the stream.
|
|
133
|
+
///
|
|
134
|
+
/// # Errors
|
|
135
|
+
///
|
|
136
|
+
/// - `Error::StreamClosed` if the remote closed the stream
|
|
137
|
+
/// - `Error::Remote` if the remote sent an error
|
|
138
|
+
/// - `Error::Cancelled` if the context was cancelled
|
|
139
|
+
async fn recv_bytes(&self) -> Result<Bytes>;
|
|
140
|
+
|
|
141
|
+
/// Closes the send side of the stream.
|
|
142
|
+
///
|
|
143
|
+
/// After calling this, `send_bytes` will return `Error::Completed`.
|
|
144
|
+
/// The receive side remains open until the remote closes.
|
|
145
|
+
async fn close_send(&self) -> Result<()>;
|
|
146
|
+
|
|
147
|
+
/// Closes both sides of the stream.
|
|
148
|
+
///
|
|
149
|
+
/// This cancels the context and releases all resources.
|
|
150
|
+
async fn close(&self) -> Result<()>;
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
/// Extension trait for typed message send/receive.
|
|
154
|
+
///
|
|
155
|
+
/// This trait provides convenience methods for sending and receiving
|
|
156
|
+
/// protobuf messages over a Stream.
|
|
157
|
+
#[async_trait]
|
|
158
|
+
pub trait StreamExt: Stream {
|
|
159
|
+
/// Sends a typed message on the stream.
|
|
160
|
+
///
|
|
161
|
+
/// The message is encoded using protobuf and sent as bytes.
|
|
162
|
+
///
|
|
163
|
+
/// # Type Parameters
|
|
164
|
+
///
|
|
165
|
+
/// - `M`: A protobuf message type that implements `prost::Message`
|
|
166
|
+
async fn msg_send<M: Message + Send + Sync>(&self, msg: &M) -> Result<()> {
|
|
167
|
+
let data = msg.encode_to_vec();
|
|
168
|
+
self.send_bytes(Bytes::from(data)).await
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
/// Receives a typed message from the stream.
|
|
172
|
+
///
|
|
173
|
+
/// The received bytes are decoded as the specified protobuf message type.
|
|
174
|
+
///
|
|
175
|
+
/// # Type Parameters
|
|
176
|
+
///
|
|
177
|
+
/// - `M`: A protobuf message type that implements `prost::Message + Default`
|
|
178
|
+
///
|
|
179
|
+
/// # Errors
|
|
180
|
+
///
|
|
181
|
+
/// - `Error::InvalidMessage` if the bytes cannot be decoded as the message type
|
|
182
|
+
/// - All errors from `recv_bytes`
|
|
183
|
+
async fn msg_recv<M: Message + Default>(&self) -> Result<M> {
|
|
184
|
+
let data = self.recv_bytes().await?;
|
|
185
|
+
M::decode(&data[..]).map_err(Error::InvalidMessage)
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// Blanket implementation for all Stream types.
|
|
190
|
+
impl<T: Stream + ?Sized> StreamExt for T {}
|
|
191
|
+
|
|
192
|
+
// Blanket implementation for Arc<T> where T: Stream
|
|
193
|
+
#[async_trait]
|
|
194
|
+
impl<T: Stream + ?Sized> Stream for Arc<T> {
|
|
195
|
+
fn context(&self) -> &Context {
|
|
196
|
+
(**self).context()
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()> {
|
|
200
|
+
(**self).send_bytes(data).await
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
async fn recv_bytes(&self) -> Result<Bytes> {
|
|
204
|
+
(**self).recv_bytes().await
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
async fn close_send(&self) -> Result<()> {
|
|
208
|
+
(**self).close_send().await
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
async fn close(&self) -> Result<()> {
|
|
212
|
+
(**self).close().await
|
|
213
|
+
}
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
// Blanket implementation for Box<T> where T: Stream
|
|
217
|
+
#[async_trait]
|
|
218
|
+
impl<T: Stream + ?Sized> Stream for Box<T> {
|
|
219
|
+
fn context(&self) -> &Context {
|
|
220
|
+
(**self).context()
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()> {
|
|
224
|
+
(**self).send_bytes(data).await
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
async fn recv_bytes(&self) -> Result<Bytes> {
|
|
228
|
+
(**self).recv_bytes().await
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
async fn close_send(&self) -> Result<()> {
|
|
232
|
+
(**self).close_send().await
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
async fn close(&self) -> Result<()> {
|
|
236
|
+
(**self).close().await
|
|
237
|
+
}
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
/// A boxed Stream trait object.
|
|
241
|
+
pub type BoxStream = Box<dyn Stream>;
|
|
242
|
+
|
|
243
|
+
/// A reference-counted Stream.
|
|
244
|
+
pub type ArcStream = Arc<dyn Stream>;
|
|
245
|
+
|
|
246
|
+
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_context_new() {
        let fresh = Context::new();
        assert!(!fresh.is_cancelled());
    }

    #[test]
    fn test_context_cancel() {
        let ctx = Context::new();
        ctx.cancel();
        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_context_child() {
        let root = Context::new();
        let derived = root.child();

        assert!(!root.is_cancelled());
        assert!(!derived.is_cancelled());

        root.cancel();

        // Cancellation propagates from parent to child.
        assert!(root.is_cancelled());
        assert!(derived.is_cancelled());
    }

    #[test]
    fn test_context_child_independent() {
        let root = Context::new();
        let derived = root.child();

        // Cancellation does not propagate from child to parent.
        derived.cancel();

        assert!(!root.is_cancelled());
        assert!(derived.is_cancelled());
    }

    #[tokio::test]
    async fn test_context_cancelled_future() {
        let ctx = Context::new();

        // A clone shares the token; cancel it from another task after a delay.
        let canceller = ctx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            canceller.cancel();
        });

        ctx.cancelled().await;
        assert!(ctx.is_cancelled());
    }
}
|