starpc 0.49.14 → 0.49.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/srpc/lib.rs CHANGED
@@ -75,6 +75,10 @@ pub mod transport;
75
75
 
76
76
  #[cfg(feature = "build")]
77
77
  pub mod build;
78
+ #[cfg(feature = "websocket")]
79
+ pub mod websocket;
80
+ #[cfg(feature = "yamux")]
81
+ pub mod yamux;
78
82
 
79
83
  // Re-exports for convenience.
80
84
  pub use client::{BoxClient, Client, OpenStream, SrpcClient};
@@ -91,6 +95,10 @@ pub use stream::{ArcStream, BoxStream, Context, Stream, StreamExt};
91
95
  pub use transport::{
92
96
  create_packet_channel, decode_optional_data, encode_optional_data, TransportPacketWriter,
93
97
  };
98
+ #[cfg(feature = "websocket")]
99
+ pub use websocket::websocket_byte_stream;
100
+ #[cfg(feature = "yamux")]
101
+ pub use yamux::YamuxStreamOpener;
94
102
 
95
103
  // Re-export async_trait for use in generated code.
96
104
  pub use async_trait::async_trait;
package/srpc/rwc-conn.go CHANGED
@@ -7,6 +7,8 @@ import (
7
7
  "os"
8
8
  "sync"
9
9
  "time"
10
+
11
+ "github.com/aperturerobotics/starpc/internal/contextutil"
10
12
  )
11
13
 
12
14
  // connPktSize is the size of the buffers to use for packets for the RwcConn.
@@ -68,7 +70,7 @@ func NewRwcConn(
68
70
  laddr, raddr net.Addr,
69
71
  bufferPacketN int,
70
72
  ) *RwcConn {
71
- ctx, ctxCancel := context.WithCancel(ctx)
73
+ ctx, ctxCancel := contextutil.WithCancel(ctx)
72
74
  if bufferPacketN <= 0 {
73
75
  bufferPacketN = 10
74
76
  }
@@ -91,13 +91,13 @@ func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
91
91
  // invokeRPC invokes the RPC after CallStart is received.
92
92
  func (r *ServerRPC) invokeRPC(serviceID, methodID string) {
93
93
  // on the server side, the writer is closed by invokeRPC.
94
- strm := NewMsgStream(r.ctx, r, r.ctxCancel)
94
+ strm := NewMsgStream(r.ctx, r, r.cancelContext)
95
95
  ok, err := r.invoker.InvokeMethod(serviceID, methodID, strm)
96
96
  if err == nil && !ok {
97
97
  err = ErrUnimplemented
98
98
  }
99
+ r.beginLocalCompletion()
99
100
  outPkt := NewCallDataPacket(nil, false, true, err)
100
101
  _ = r.writer.WritePacket(outPkt)
101
- _ = r.writer.Close()
102
- r.ctxCancel()
102
+ r.finishLocalCompletion()
103
103
  }
package/srpc/server.go CHANGED
@@ -3,6 +3,8 @@ package srpc
3
3
  import (
4
4
  "context"
5
5
  "io"
6
+
7
+ "github.com/aperturerobotics/starpc/internal/contextutil"
6
8
  )
7
9
 
8
10
  // Server handles incoming RPC streams with a mux.
@@ -26,7 +28,7 @@ func (s *Server) GetInvoker() Invoker {
26
28
  // HandleStream handles an incoming stream and runs the read loop.
27
29
  // Uses length-prefixed packets.
28
30
  func (s *Server) HandleStream(ctx context.Context, rwc io.ReadWriteCloser) {
29
- subCtx, subCtxCancel := context.WithCancel(ctx)
31
+ subCtx, subCtxCancel := contextutil.WithCancel(ctx)
30
32
  defer subCtxCancel()
31
33
  prw := NewPacketReadWriter(rwc)
32
34
  serverRPC := NewServerRPC(subCtx, s.invoker, prw)
package/srpc/server.rs CHANGED
@@ -104,7 +104,7 @@ impl<I: Invoker + 'static> Server<I> {
104
104
  }
105
105
 
106
106
  /// Reports an error through the error handler, if configured.
107
- fn report_error(&self, err: Error) {
107
+ pub(crate) fn report_error(&self, err: Error) {
108
108
  if let Some(ref handler) = self.error_handler {
109
109
  handler(err);
110
110
  }
@@ -214,6 +214,44 @@ impl<I: Invoker + 'static> Server<I> {
214
214
  Ok(())
215
215
  }
216
216
 
/// handle_yamux serves a yamux connection, routing each accepted substream as a
/// Starpc stream.
///
/// This is `handle_yamux_with_config` called with `yamux::Config::default()`;
/// its result is returned unchanged.
#[cfg(feature = "yamux")]
pub async fn handle_yamux<T>(&self, transport: T) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let config = yamux::Config::default();
    self.handle_yamux_with_config(transport, config).await
}
227
+
/// handle_yamux_with_config serves a yamux connection using the yamux
/// configuration supplied by the caller.
///
/// The transport and configuration are handed to
/// `crate::yamux::handle_server_connection` together with this server, and the
/// result of that call is returned as-is.
#[cfg(feature = "yamux")]
pub async fn handle_yamux_with_config<T>(
    &self,
    transport: T,
    config: yamux::Config,
) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let serving = crate::yamux::handle_server_connection(self, transport, config);
    serving.await
}
241
+
/// handle_websocket_yamux serves a WebSocket connection that carries yamux
/// substreams.
///
/// The socket is first adapted to a byte stream with
/// `crate::websocket::websocket_byte_stream`, then served through `handle_yamux`.
#[cfg(all(feature = "websocket", feature = "yamux"))]
pub async fn handle_websocket_yamux<T>(
    &self,
    socket: tokio_tungstenite::WebSocketStream<T>,
) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let byte_stream = crate::websocket::websocket_byte_stream(socket);
    self.handle_yamux(byte_stream).await
}
254
+
217
255
  /// Accepts and handles connections in a loop.
218
256
  ///
219
257
  /// This is a convenience method that accepts connections from a listener
@@ -246,7 +284,7 @@ impl<I: Invoker + 'static> Server<I> {
246
284
  }
247
285
 
248
286
  /// Creates a clone of the server for spawning tasks.
249
- fn clone_for_spawn(&self) -> Server<I> {
287
+ pub(crate) fn clone_for_spawn(&self) -> Server<I> {
250
288
  Server {
251
289
  invoker: self.invoker.clone(),
252
290
  config: self.config.clone(),
@@ -140,6 +140,94 @@ describe('srpc server', () => {
140
140
  expect([...body.value.data]).toEqual([...response])
141
141
  })
142
142
 
it('keeps StreamConn server-streaming responses open after Go-style request close', async () => {
  const serviceName = 'test.ResourceService'
  const methodName = 'ResourceClient'
  const response = new TextEncoder().encode('streamconn init')

  // Server-side handler: ignores the request source and emits a single
  // message after a short delay, i.e. after the client finished its sink.
  const mux = createMux()
  mux.registerLookupMethod(async (service, method) => {
    const isTarget = service === serviceName && method === methodName
    if (!isTarget) {
      return null
    }
    return async (_dataSource, dataSink) => {
      const delayedResponse = async function* () {
        await new Promise((resolve) => setTimeout(resolve, 10))
        yield response
      }
      await dataSink(delayedResponse())
    }
  })

  // Wire a client StreamConn and an inbound server StreamConn together over
  // the two ports of a MessageChannel.
  const server = new Server(mux.lookupMethod)
  const clientConn = new StreamConn()
  const serverConn = new StreamConn(server, { direction: 'inbound' })
  const channel = new MessageChannel()
  const clientChannelStream = new ChannelStream('client', channel.port1)
  const serverChannelStream = new ChannelStream('server', channel.port2)

  pipe(
    clientChannelStream,
    clientConn,
    combineUint8ArrayListTransform(),
    clientChannelStream,
  )
    .catch((err: Error) => clientConn.close(err))
    .then(() => clientConn.close())

  pipe(
    serverChannelStream,
    serverConn,
    combineUint8ArrayListTransform(),
    serverChannelStream,
  )
    .catch((err: Error) => serverConn.close(err))
    .then(() => serverConn.close())

  try {
    const stream = await clientConn.openStream()

    // Start consuming responses before the request is written so the first
    // callData packet cannot be missed.
    const firstResponse = (async () => {
      for await (const packetData of stream.source) {
        const decoded = Packet.fromBinary(packetData)
        if (decoded.body?.case !== 'callData') {
          continue
        }
        return decoded
      }
      throw new Error('server response stream ended before call data')
    })()

    // The request side yields exactly one callStart packet and then ends.
    const requestPackets = async function* () {
      yield Packet.toBinary({
        body: {
          case: 'callStart',
          value: {
            rpcService: serviceName,
            rpcMethod: methodName,
            data: new Uint8Array(0),
            dataIsZero: true,
          },
        },
      })
    }
    await stream.sink(requestPackets())

    const packet = await promiseWithTimeout(
      firstResponse,
      'StreamConn server-streaming response',
    )
    const { body } = packet
    expect(body?.case).toBe('callData')
    if (body?.case !== 'callData' || !body.value?.data) {
      throw new Error('expected callData packet')
    }
    expect(Array.from(body.value.data)).toEqual(Array.from(response))
  } finally {
    clientConn.close()
    serverConn.close()
    clientChannelStream.close()
    serverChannelStream.close()
  }
})
230
+
143
231
  it('removes abort listeners after a request completes', async () => {
144
232
  const controller = new AbortController()
145
233
  const removeEventListener = vi.spyOn(
@@ -4,6 +4,8 @@ import (
4
4
  "context"
5
5
  "io"
6
6
  "sync"
7
+
8
+ "github.com/aperturerobotics/starpc/internal/contextutil"
7
9
  )
8
10
 
9
11
  // pipeStream implements an in-memory stream.
@@ -22,9 +24,9 @@ type pipeStream struct {
22
24
  // NewPipeStream constructs a new in-memory stream.
23
25
  func NewPipeStream(ctx context.Context) (Stream, Stream) {
24
26
  s1 := &pipeStream{dataCh: make(chan []byte, 5)}
25
- s1.ctx, s1.ctxCancel = context.WithCancel(ctx)
27
+ s1.ctx, s1.ctxCancel = contextutil.WithCancel(ctx)
26
28
  s2 := &pipeStream{other: s1, dataCh: make(chan []byte, 5)}
27
- s2.ctx, s2.ctxCancel = context.WithCancel(ctx)
29
+ s2.ctx, s2.ctxCancel = contextutil.WithCancel(ctx)
28
30
  s1.other = s2
29
31
  return s1, s2
30
32
  }
@@ -0,0 +1,70 @@
1
+ //! WebSocket transport adapters for starpc.
2
+
3
+ use bytes::Bytes;
4
+ use futures::{SinkExt, StreamExt};
5
+ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream};
6
+ use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
7
+
/// WEBSOCKET_BYTE_STREAM_BUFFER is the size in bytes (64 KiB) given to the
/// in-memory duplex pipe that backs a WebSocket byte adapter.
pub const WEBSOCKET_BYTE_STREAM_BUFFER: usize = 64 * 1024;

/// WEBSOCKET_MESSAGE_BUFFER is the largest number of bytes (16 KiB) placed in a
/// single outbound WebSocket message.
pub const WEBSOCKET_MESSAGE_BUFFER: usize = 16 * 1024;
13
+
14
+ /// websocket_byte_stream adapts a binary WebSocket stream to Tokio byte I/O.
15
+ ///
16
+ /// The adapter treats incoming binary WebSocket messages as consecutive bytes
17
+ /// and writes outbound byte chunks as binary WebSocket messages. Control and
18
+ /// text messages are ignored except that close terminates the byte stream.
19
+ pub fn websocket_byte_stream<S>(socket: WebSocketStream<S>) -> DuplexStream
20
+ where
21
+ S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
22
+ {
23
+ let (mut socket_writer, mut socket_reader) = socket.split();
24
+ let (stream, peer) = tokio::io::duplex(WEBSOCKET_BYTE_STREAM_BUFFER);
25
+ let (mut peer_reader, mut peer_writer) = tokio::io::split(peer);
26
+
27
+ tokio::spawn(async move {
28
+ while let Some(result) = socket_reader.next().await {
29
+ match result {
30
+ Ok(message) if message.is_binary() => {
31
+ let data = message.into_data();
32
+ if peer_writer.write_all(data.as_ref()).await.is_err() {
33
+ break;
34
+ }
35
+ if peer_writer.flush().await.is_err() {
36
+ break;
37
+ }
38
+ }
39
+ Ok(message) if message.is_close() => break,
40
+ Ok(_) => {}
41
+ Err(_) => break,
42
+ }
43
+ }
44
+ let _ = peer_writer.shutdown().await;
45
+ });
46
+
47
+ tokio::spawn(async move {
48
+ let mut buf = vec![0; WEBSOCKET_MESSAGE_BUFFER];
49
+ loop {
50
+ match peer_reader.read(&mut buf).await {
51
+ Ok(0) => {
52
+ let _ = socket_writer.close().await;
53
+ break;
54
+ }
55
+ Ok(n) => {
56
+ let message = Message::Binary(Bytes::copy_from_slice(&buf[..n]));
57
+ if socket_writer.send(message).await.is_err() {
58
+ break;
59
+ }
60
+ }
61
+ Err(_) => {
62
+ let _ = socket_writer.close().await;
63
+ break;
64
+ }
65
+ }
66
+ }
67
+ });
68
+
69
+ stream
70
+ }
package/srpc/yamux.rs ADDED
@@ -0,0 +1,198 @@
1
+ //! Yamux transport adapters for starpc.
2
+
3
+ use async_trait::async_trait;
4
+ use futures::future::poll_fn;
5
+ use std::io;
6
+ use std::pin::Pin;
7
+ use std::sync::Arc;
8
+ use tokio::io::{AsyncRead, AsyncWrite};
9
+ use tokio::sync::{mpsc, oneshot};
10
+ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
11
+
12
+ use crate::client::{OpenStream, PacketReceiver};
13
+ use crate::error::{Error, Result};
14
+ use crate::invoker::Invoker;
15
+ use crate::rpc::PacketWriter;
16
+ use crate::server::Server;
17
+ use crate::transport::{create_packet_channel, DEFAULT_CHANNEL_BUFFER};
18
+
19
+ type OpenResult = std::result::Result<::yamux::Stream, ::yamux::ConnectionError>;
20
+ type OpenRequest = oneshot::Sender<OpenResult>;
21
+
22
+ /// YamuxStreamOpener opens Starpc packet streams over one yamux connection.
23
+ #[derive(Clone)]
24
+ pub struct YamuxStreamOpener {
25
+ requests: mpsc::Sender<OpenRequest>,
26
+ }
27
+
28
+ impl YamuxStreamOpener {
29
+ /// client creates a client-mode yamux opener over a Tokio transport.
30
+ pub fn client<T>(transport: T) -> Self
31
+ where
32
+ T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
33
+ {
34
+ Self::client_with_config(transport, ::yamux::Config::default())
35
+ }
36
+
37
+ /// client_with_config creates a client-mode yamux opener with a custom config.
38
+ pub fn client_with_config<T>(transport: T, config: ::yamux::Config) -> Self
39
+ where
40
+ T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
41
+ {
42
+ let (requests, request_rx) = mpsc::channel(DEFAULT_CHANNEL_BUFFER);
43
+ spawn_client_driver(transport.compat(), config, request_rx);
44
+ Self { requests }
45
+ }
46
+
47
+ /// client_websocket creates a client-mode yamux opener over a WebSocket.
48
+ #[cfg(feature = "websocket")]
49
+ pub fn client_websocket<S>(socket: tokio_tungstenite::WebSocketStream<S>) -> Self
50
+ where
51
+ S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
52
+ {
53
+ Self::client(crate::websocket::websocket_byte_stream(socket))
54
+ }
55
+
56
+ /// client_websocket_with_config creates a client-mode yamux WebSocket opener with a custom config.
57
+ #[cfg(feature = "websocket")]
58
+ pub fn client_websocket_with_config<S>(
59
+ socket: tokio_tungstenite::WebSocketStream<S>,
60
+ config: ::yamux::Config,
61
+ ) -> Self
62
+ where
63
+ S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
64
+ {
65
+ Self::client_with_config(crate::websocket::websocket_byte_stream(socket), config)
66
+ }
67
+ }
68
+
69
+ #[async_trait]
70
+ impl OpenStream for YamuxStreamOpener {
71
+ async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
72
+ let (tx, rx) = oneshot::channel();
73
+ self.requests
74
+ .send(tx)
75
+ .await
76
+ .map_err(|_| Error::StreamClosed)?;
77
+
78
+ let stream = rx
79
+ .await
80
+ .map_err(|_| Error::StreamClosed)?
81
+ .map_err(connection_error)?;
82
+ let stream = stream.compat();
83
+ let (read_half, write_half) = tokio::io::split(stream);
84
+ Ok(create_packet_channel(read_half, write_half))
85
+ }
86
+ }
87
+
88
+ /// handle_server_connection accepts yamux streams and serves each as Starpc.
89
+ pub async fn handle_server_connection<I, T>(
90
+ server: &Server<I>,
91
+ transport: T,
92
+ config: ::yamux::Config,
93
+ ) -> Result<()>
94
+ where
95
+ I: Invoker + 'static,
96
+ T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
97
+ {
98
+ let mut connection =
99
+ ::yamux::Connection::new(transport.compat(), config, ::yamux::Mode::Server);
100
+
101
+ loop {
102
+ match poll_fn(|cx| connection.poll_next_inbound(cx)).await {
103
+ Some(Ok(stream)) => {
104
+ let server = server.clone_for_spawn();
105
+ tokio::spawn(async move {
106
+ if let Err(err) = server.handle_stream(stream.compat()).await {
107
+ server.report_error(err);
108
+ }
109
+ });
110
+ }
111
+ Some(Err(err)) => return Err(connection_error(err)),
112
+ None => return Ok(()),
113
+ }
114
+ }
115
+ }
116
+
117
+ enum DriverEvent {
118
+ Opened(OpenRequest, OpenResult),
119
+ Closed {
120
+ pending: Option<OpenRequest>,
121
+ err: Option<::yamux::ConnectionError>,
122
+ },
123
+ }
124
+
125
+ fn spawn_client_driver<T>(
126
+ transport: T,
127
+ config: ::yamux::Config,
128
+ mut requests: mpsc::Receiver<OpenRequest>,
129
+ ) -> tokio::task::JoinHandle<()>
130
+ where
131
+ T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
132
+ {
133
+ tokio::spawn(async move {
134
+ let mut connection = ::yamux::Connection::new(transport, config, ::yamux::Mode::Client);
135
+ let mut pending = None;
136
+ let mut requests_closed = false;
137
+
138
+ loop {
139
+ let event = poll_fn(|cx| {
140
+ if pending.is_none() && !requests_closed {
141
+ match Pin::new(&mut requests).poll_recv(cx) {
142
+ std::task::Poll::Ready(Some(request)) => pending = Some(request),
143
+ std::task::Poll::Ready(None) => requests_closed = true,
144
+ std::task::Poll::Pending => {}
145
+ }
146
+ }
147
+
148
+ if let Some(request) = pending.take() {
149
+ match connection.poll_new_outbound(cx) {
150
+ std::task::Poll::Ready(result) => {
151
+ return std::task::Poll::Ready(DriverEvent::Opened(request, result));
152
+ }
153
+ std::task::Poll::Pending => pending = Some(request),
154
+ }
155
+ }
156
+
157
+ loop {
158
+ match connection.poll_next_inbound(cx) {
159
+ std::task::Poll::Ready(Some(Ok(_stream))) => continue,
160
+ std::task::Poll::Ready(Some(Err(err))) => {
161
+ return std::task::Poll::Ready(DriverEvent::Closed {
162
+ pending: pending.take(),
163
+ err: Some(err),
164
+ });
165
+ }
166
+ std::task::Poll::Ready(None) => {
167
+ return std::task::Poll::Ready(DriverEvent::Closed {
168
+ pending: pending.take(),
169
+ err: None,
170
+ });
171
+ }
172
+ std::task::Poll::Pending => return std::task::Poll::Pending,
173
+ }
174
+ }
175
+ })
176
+ .await;
177
+
178
+ match event {
179
+ DriverEvent::Opened(request, result) => {
180
+ let _ = request.send(result);
181
+ }
182
+ DriverEvent::Closed { pending, err } => {
183
+ if let Some(request) = pending {
184
+ let _ = request.send(Err(err.unwrap_or(::yamux::ConnectionError::Closed)));
185
+ }
186
+ break;
187
+ }
188
+ }
189
+ }
190
+ })
191
+ }
192
+
193
+ fn connection_error(err: ::yamux::ConnectionError) -> Error {
194
+ match err {
195
+ ::yamux::ConnectionError::Io(err) => Error::Io(err),
196
+ other => Error::Io(io::Error::new(io::ErrorKind::Other, other.to_string())),
197
+ }
198
+ }