starpc 0.49.14 → 0.49.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/dist/srpc/server.test.js +65 -0
- package/package.json +1 -1
- package/srpc/client-rpc.go +5 -1
- package/srpc/client.rs +4 -0
- package/srpc/common-rpc.go +19 -8
- package/srpc/common-rpc_test.go +38 -0
- package/srpc/lib.rs +8 -0
- package/srpc/server.rs +40 -2
- package/srpc/server.test.ts +88 -0
- package/srpc/websocket.rs +70 -0
- package/srpc/yamux.rs +198 -0
package/README.md
CHANGED
|
@@ -232,7 +232,7 @@ import { WebSocketConn } from 'starpc'
|
|
|
232
232
|
import { EchoerClient } from './echo/index.js'
|
|
233
233
|
|
|
234
234
|
const ws = new WebSocket('ws://localhost:8080/api')
|
|
235
|
-
const conn = new WebSocketConn(ws)
|
|
235
|
+
const conn = new WebSocketConn(ws, 'outbound')
|
|
236
236
|
const client = conn.buildClient()
|
|
237
237
|
const echoer = new EchoerClient(client)
|
|
238
238
|
|
package/dist/srpc/server.test.js
CHANGED
|
@@ -94,6 +94,71 @@ describe('srpc server', () => {
|
|
|
94
94
|
}
|
|
95
95
|
expect([...body.value.data]).toEqual([...response]);
|
|
96
96
|
});
|
|
97
|
+
it('keeps StreamConn server-streaming responses open after Go-style request close', async () => {
    // The handler answers only after a short delay, i.e. after the client has already
    // finished writing (and closed) its request side.
    const expected = new TextEncoder().encode('streamconn init');
    const mux = createMux();
    mux.registerLookupMethod(async (service, method) => {
        const isResourceClient = service === 'test.ResourceService' && method === 'ResourceClient';
        if (!isResourceClient) {
            return null;
        }
        return async (_dataSource, dataSink) => {
            const delayedReply = async function* () {
                await new Promise((resolve) => setTimeout(resolve, 10));
                yield expected;
            };
            await dataSink(delayedReply());
        };
    });
    const server = new Server(mux.lookupMethod);
    const clientConn = new StreamConn();
    const serverConn = new StreamConn(server, { direction: 'inbound' });
    const { port1, port2 } = new MessageChannel();
    const clientChannelStream = new ChannelStream('client', port1);
    const serverChannelStream = new ChannelStream('server', port2);
    // Wire each conn to its channel; close the conn when the pipe finishes or fails.
    const wire = (channelStream, conn) => {
        pipe(channelStream, conn, combineUint8ArrayListTransform(), channelStream)
            .catch((err) => conn.close(err))
            .then(() => conn.close());
    };
    wire(clientChannelStream, clientConn);
    wire(serverChannelStream, serverConn);
    try {
        const stream = await clientConn.openStream();
        // Consume responses concurrently with sending the request.
        const firstCallData = (async () => {
            for await (const raw of stream.source) {
                const decoded = Packet.fromBinary(raw);
                if (decoded.body?.case === 'callData') {
                    return decoded;
                }
            }
            throw new Error('server response stream ended before call data');
        })();
        // The request source ends right after callStart, closing the client's send side.
        const requestPackets = async function* () {
            yield Packet.toBinary({
                body: {
                    case: 'callStart',
                    value: {
                        rpcService: 'test.ResourceService',
                        rpcMethod: 'ResourceClient',
                        data: new Uint8Array(0),
                        dataIsZero: true,
                    },
                },
            });
        };
        await stream.sink(requestPackets());
        const { body } = await promiseWithTimeout(firstCallData, 'StreamConn server-streaming response');
        expect(body?.case).toBe('callData');
        if (body?.case !== 'callData' || !body.value?.data) {
            throw new Error('expected callData packet');
        }
        expect([...body.value.data]).toEqual([...expected]);
    }
    finally {
        clientConn.close();
        serverConn.close();
        clientChannelStream.close();
        serverChannelStream.close();
    }
});
|
|
97
162
|
it('removes abort listeners after a request completes', async () => {
|
|
98
163
|
const controller = new AbortController();
|
|
99
164
|
const removeEventListener = vi.spyOn(controller.signal, 'removeEventListener');
|
package/package.json
CHANGED
package/srpc/client-rpc.go
CHANGED
|
@@ -109,10 +109,14 @@ func (r *ClientRPC) HandleCallStart(pkt *CallStart) error {
|
|
|
109
109
|
// Close releases any resources held by the ClientRPC.
|
|
110
110
|
func (r *ClientRPC) Close() {
|
|
111
111
|
locked := r.bcast.Lock()
|
|
112
|
+
var writer PacketWriter
|
|
112
113
|
// call did not start yet if writer is nil.
|
|
113
114
|
if r.writer != nil {
|
|
114
115
|
_ = r.WriteCallCancel()
|
|
115
|
-
r.closeLocked(&locked)
|
|
116
|
+
writer = r.closeLocked(&locked)
|
|
116
117
|
}
|
|
117
118
|
locked.Unlock()
|
|
119
|
+
if writer != nil {
|
|
120
|
+
_ = writer.Close()
|
|
121
|
+
}
|
|
118
122
|
}
|
package/srpc/client.rs
CHANGED
|
@@ -262,6 +262,10 @@ pub mod transport {
|
|
|
262
262
|
|
|
263
263
|
/// Re-export TransportPacketWriter for direct use.
|
|
264
264
|
pub use crate::transport::TransportPacketWriter;
|
|
265
|
+
|
|
266
|
+
/// Re-export YamuxStreamOpener for multiplexed transports.
|
|
267
|
+
#[cfg(feature = "yamux")]
|
|
268
|
+
pub use crate::yamux::YamuxStreamOpener;
|
|
265
269
|
}
|
|
266
270
|
|
|
267
271
|
#[cfg(test)]
|
package/srpc/common-rpc.go
CHANGED
|
@@ -26,6 +26,8 @@ type commonRPC struct {
|
|
|
26
26
|
bcast broadcast.Broadcast
|
|
27
27
|
// writer is the writer to write messages to
|
|
28
28
|
writer PacketWriter
|
|
29
|
+
// writerClosed is set after writer has been closed locally.
|
|
30
|
+
writerClosed bool
|
|
29
31
|
// dataQueue contains incoming data packets.
|
|
30
32
|
// note: packets may be len() == 0
|
|
31
33
|
dataQueue [][]byte
|
|
@@ -86,8 +88,11 @@ func (c *commonRPC) ReadOne() ([]byte, error) {
|
|
|
86
88
|
locked := c.bcast.Lock()
|
|
87
89
|
if ctxDone && !c.dataClosed {
|
|
88
90
|
// context must have been canceled locally
|
|
89
|
-
c.closeLocked(&locked)
|
|
91
|
+
writer := c.closeLocked(&locked)
|
|
90
92
|
locked.Unlock()
|
|
93
|
+
if writer != nil {
|
|
94
|
+
_ = writer.Close()
|
|
95
|
+
}
|
|
91
96
|
return nil, context.Canceled
|
|
92
97
|
}
|
|
93
98
|
|
|
@@ -144,7 +149,7 @@ func (c *commonRPC) WriteCallData(data []byte, dataIsZero, complete bool, err er
|
|
|
144
149
|
func (c *commonRPC) HandleStreamClose(closeErr error) {
|
|
145
150
|
var writer PacketWriter
|
|
146
151
|
locked := c.bcast.Lock()
|
|
147
|
-
if c.dataClosed {
|
|
152
|
+
if c.dataClosed && c.writerClosed {
|
|
148
153
|
locked.Unlock()
|
|
149
154
|
return
|
|
150
155
|
}
|
|
@@ -153,7 +158,7 @@ func (c *commonRPC) HandleStreamClose(closeErr error) {
|
|
|
153
158
|
}
|
|
154
159
|
c.dataClosed = true
|
|
155
160
|
c.ctxCancel()
|
|
156
|
-
writer = c.
|
|
161
|
+
writer = c.closeWriterLocked()
|
|
157
162
|
locked.Broadcast()
|
|
158
163
|
locked.Unlock()
|
|
159
164
|
if writer != nil {
|
|
@@ -212,16 +217,22 @@ func (c *commonRPC) WriteCallCancel() error {
|
|
|
212
217
|
}
|
|
213
218
|
|
|
214
219
|
// closeLocked releases resources held by the RPC.
|
|
215
|
-
func (c *commonRPC) closeLocked(locked *broadcast.Locked) {
|
|
216
|
-
if c.dataClosed {
|
|
217
|
-
return
|
|
218
|
-
}
|
|
220
|
+
func (c *commonRPC) closeLocked(locked *broadcast.Locked) PacketWriter {
|
|
219
221
|
c.dataClosed = true
|
|
220
222
|
c.localCompleted.Store(true)
|
|
221
223
|
if c.remoteErr == nil {
|
|
222
224
|
c.remoteErr = context.Canceled
|
|
223
225
|
}
|
|
224
|
-
|
|
226
|
+
writer := c.closeWriterLocked()
|
|
225
227
|
locked.Broadcast()
|
|
226
228
|
c.ctxCancel()
|
|
229
|
+
return writer
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
func (c *commonRPC) closeWriterLocked() PacketWriter {
|
|
233
|
+
if c.writerClosed || c.writer == nil {
|
|
234
|
+
return nil
|
|
235
|
+
}
|
|
236
|
+
c.writerClosed = true
|
|
237
|
+
return c.writer
|
|
227
238
|
}
|
package/srpc/common-rpc_test.go
CHANGED
|
@@ -69,6 +69,44 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
|
|
|
69
69
|
}
|
|
70
70
|
}
|
|
71
71
|
|
|
72
|
+
func TestClientRPCCloseAfterRemoteCompleteClosesWriter(t *testing.T) {
|
|
73
|
+
writer := &closeCountingPacketWriter{}
|
|
74
|
+
rpc := NewClientRPC(context.Background(), "service", "method")
|
|
75
|
+
if err := rpc.Start(writer, false, nil); err != nil {
|
|
76
|
+
t.Fatalf("start: %v", err)
|
|
77
|
+
}
|
|
78
|
+
if err := rpc.HandleCallData(NewCallDataPacket([]byte("ok"), false, true, nil).GetCallData()); err != nil {
|
|
79
|
+
t.Fatalf("handle call data: %v", err)
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
rpc.Close()
|
|
83
|
+
|
|
84
|
+
if got := writer.closed.Load(); got != 1 {
|
|
85
|
+
t.Fatalf("expected writer closed once, got %d", got)
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
func TestClientRPCCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
|
|
90
|
+
var rpc *ClientRPC
|
|
91
|
+
writerClosedOutsideLock := false
|
|
92
|
+
writer := &closeCallbackPacketWriter{
|
|
93
|
+
closeFn: func() {
|
|
94
|
+
ok := rpc.bcast.TryHoldLock(func(func(), func() <-chan struct{}) {})
|
|
95
|
+
writerClosedOutsideLock = ok
|
|
96
|
+
},
|
|
97
|
+
}
|
|
98
|
+
rpc = NewClientRPC(context.Background(), "service", "method")
|
|
99
|
+
if err := rpc.Start(writer, false, nil); err != nil {
|
|
100
|
+
t.Fatalf("start: %v", err)
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
rpc.Close()
|
|
104
|
+
|
|
105
|
+
if !writerClosedOutsideLock {
|
|
106
|
+
t.Fatal("expected writer close outside broadcast lock")
|
|
107
|
+
}
|
|
108
|
+
}
|
|
109
|
+
|
|
72
110
|
func TestCommonRPCReadOneQueuedDoesNotAllocate(t *testing.T) {
|
|
73
111
|
msg := []byte("message")
|
|
74
112
|
queue := make([][]byte, 1)
|
package/srpc/lib.rs
CHANGED
|
@@ -75,6 +75,10 @@ pub mod transport;
|
|
|
75
75
|
|
|
76
76
|
#[cfg(feature = "build")]
|
|
77
77
|
pub mod build;
|
|
78
|
+
#[cfg(feature = "websocket")]
|
|
79
|
+
pub mod websocket;
|
|
80
|
+
#[cfg(feature = "yamux")]
|
|
81
|
+
pub mod yamux;
|
|
78
82
|
|
|
79
83
|
// Re-exports for convenience.
|
|
80
84
|
pub use client::{BoxClient, Client, OpenStream, SrpcClient};
|
|
@@ -91,6 +95,10 @@ pub use stream::{ArcStream, BoxStream, Context, Stream, StreamExt};
|
|
|
91
95
|
pub use transport::{
|
|
92
96
|
create_packet_channel, decode_optional_data, encode_optional_data, TransportPacketWriter,
|
|
93
97
|
};
|
|
98
|
+
// Optional transport re-exports. They use `crate::` paths because the crate root
// declares a `yamux` module while also depending on the external `yamux` crate.
// A bare `use yamux::...` here is ambiguous between the two (rustc E0659).
#[cfg(feature = "websocket")]
pub use crate::websocket::websocket_byte_stream;
#[cfg(feature = "yamux")]
pub use crate::yamux::YamuxStreamOpener;
|
|
94
102
|
|
|
95
103
|
// Re-export async_trait for use in generated code.
|
|
96
104
|
pub use async_trait::async_trait;
|
package/srpc/server.rs
CHANGED
|
@@ -104,7 +104,7 @@ impl<I: Invoker + 'static> Server<I> {
|
|
|
104
104
|
}
|
|
105
105
|
|
|
106
106
|
/// Reports an error through the error handler, if configured.
|
|
107
|
-
fn report_error(&self, err: Error) {
|
|
107
|
+
pub(crate) fn report_error(&self, err: Error) {
|
|
108
108
|
if let Some(ref handler) = self.error_handler {
|
|
109
109
|
handler(err);
|
|
110
110
|
}
|
|
@@ -214,6 +214,44 @@ impl<I: Invoker + 'static> Server<I> {
|
|
|
214
214
|
Ok(())
|
|
215
215
|
}
|
|
216
216
|
|
|
217
|
+
/// Serves a yamux-multiplexed connection using the default yamux configuration.
///
/// Equivalent to [`Server::handle_yamux_with_config`] with `::yamux::Config::default()`.
/// Every substream the peer opens is handled as an independent Starpc stream.
///
/// # Errors
///
/// Returns the mapped yamux connection error if the session fails.
#[cfg(feature = "yamux")]
pub async fn handle_yamux<T>(&self, transport: T) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // `::yamux` names the external crate explicitly so the path cannot resolve to,
    // or be ambiguous with, the crate-local `yamux` module.
    self.handle_yamux_with_config(transport, ::yamux::Config::default())
        .await
}
|
|
227
|
+
|
|
228
|
+
/// Serves a yamux-multiplexed connection using the supplied yamux configuration.
///
/// `transport` is run as the server side of a yamux session. Each substream the peer
/// opens is handled as an independent Starpc stream on its own Tokio task. Errors from
/// those streams are passed to the configured error handler, if any.
///
/// # Errors
///
/// Returns the yamux connection error (mapped to a starpc error) if the session fails.
/// Returns `Ok(())` once the session closes.
#[cfg(feature = "yamux")]
pub async fn handle_yamux_with_config<T>(
    &self,
    transport: T,
    config: ::yamux::Config,
) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // `::yamux::Config` is the external crate's type (same type as before); the leading
    // `::` keeps it distinct from the crate-local `yamux` module.
    crate::yamux::handle_server_connection(self, transport, config).await
}
|
|
241
|
+
|
|
242
|
+
/// Serves a yamux session tunnelled through a WebSocket connection.
///
/// The WebSocket is first adapted into a byte stream with
/// [`crate::websocket::websocket_byte_stream`]. The result is then handled exactly like
/// [`Server::handle_yamux`], i.e. with the default yamux configuration.
#[cfg(all(feature = "websocket", feature = "yamux"))]
pub async fn handle_websocket_yamux<T>(
    &self,
    socket: tokio_tungstenite::WebSocketStream<T>,
) -> Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let byte_stream = crate::websocket::websocket_byte_stream(socket);
    self.handle_yamux(byte_stream).await
}
|
|
254
|
+
|
|
217
255
|
/// Accepts and handles connections in a loop.
|
|
218
256
|
///
|
|
219
257
|
/// This is a convenience method that accepts connections from a listener
|
|
@@ -246,7 +284,7 @@ impl<I: Invoker + 'static> Server<I> {
|
|
|
246
284
|
}
|
|
247
285
|
|
|
248
286
|
/// Creates a clone of the server for spawning tasks.
|
|
249
|
-
fn clone_for_spawn(&self) -> Server<I> {
|
|
287
|
+
pub(crate) fn clone_for_spawn(&self) -> Server<I> {
|
|
250
288
|
Server {
|
|
251
289
|
invoker: self.invoker.clone(),
|
|
252
290
|
config: self.config.clone(),
|
package/srpc/server.test.ts
CHANGED
|
@@ -140,6 +140,94 @@ describe('srpc server', () => {
|
|
|
140
140
|
expect([...body.value.data]).toEqual([...response])
|
|
141
141
|
})
|
|
142
142
|
|
|
143
|
+
it('keeps StreamConn server-streaming responses open after Go-style request close', async () => {
  // The handler answers only after a short delay, i.e. after the client has already
  // finished writing (and closed) its request side.
  const expected = new TextEncoder().encode('streamconn init')
  const mux = createMux()
  mux.registerLookupMethod(async (service, method) => {
    const isResourceClient =
      service === 'test.ResourceService' && method === 'ResourceClient'
    if (!isResourceClient) {
      return null
    }
    return async (_dataSource, dataSink) => {
      const delayedReply = async function* () {
        await new Promise((resolve) => setTimeout(resolve, 10))
        yield expected
      }
      await dataSink(delayedReply())
    }
  })

  const server = new Server(mux.lookupMethod)
  const clientConn = new StreamConn()
  const serverConn = new StreamConn(server, { direction: 'inbound' })
  const { port1, port2 } = new MessageChannel()
  const clientChannelStream = new ChannelStream('client', port1)
  const serverChannelStream = new ChannelStream('server', port2)

  // Wire each conn to its channel; close the conn when the pipe finishes or fails.
  pipe(
    clientChannelStream,
    clientConn,
    combineUint8ArrayListTransform(),
    clientChannelStream,
  )
    .catch((err: Error) => clientConn.close(err))
    .then(() => clientConn.close())
  pipe(
    serverChannelStream,
    serverConn,
    combineUint8ArrayListTransform(),
    serverChannelStream,
  )
    .catch((err: Error) => serverConn.close(err))
    .then(() => serverConn.close())

  try {
    const stream = await clientConn.openStream()

    // Consume responses concurrently with sending the request.
    const firstCallData = (async () => {
      for await (const raw of stream.source) {
        const decoded = Packet.fromBinary(raw)
        if (decoded.body?.case === 'callData') {
          return decoded
        }
      }
      throw new Error('server response stream ended before call data')
    })()

    // The request source ends right after callStart, closing the client's send side.
    const requestPackets = async function* () {
      yield Packet.toBinary({
        body: {
          case: 'callStart',
          value: {
            rpcService: 'test.ResourceService',
            rpcMethod: 'ResourceClient',
            data: new Uint8Array(0),
            dataIsZero: true,
          },
        },
      })
    }
    await stream.sink(requestPackets())

    const { body } = await promiseWithTimeout(
      firstCallData,
      'StreamConn server-streaming response',
    )
    expect(body?.case).toBe('callData')
    if (body?.case !== 'callData' || !body.value?.data) {
      throw new Error('expected callData packet')
    }
    expect([...body.value.data]).toEqual([...expected])
  } finally {
    clientConn.close()
    serverConn.close()
    clientChannelStream.close()
    serverChannelStream.close()
  }
})
|
|
230
|
+
|
|
143
231
|
it('removes abort listeners after a request completes', async () => {
|
|
144
232
|
const controller = new AbortController()
|
|
145
233
|
const removeEventListener = vi.spyOn(
|
|
package/srpc/websocket.rs
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
//! WebSocket transport adapters for starpc.
|
|
2
|
+
|
|
3
|
+
use bytes::Bytes;
|
|
4
|
+
use futures::{SinkExt, StreamExt};
|
|
5
|
+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream};
|
|
6
|
+
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
|
|
7
|
+
|
|
8
|
+
/// Capacity, in bytes (64 KiB), of the in-memory duplex pipe created by
/// [`websocket_byte_stream`].
pub const WEBSOCKET_BYTE_STREAM_BUFFER: usize = 64 << 10;

/// Largest payload, in bytes (16 KiB), that [`websocket_byte_stream`] puts into a
/// single outbound binary WebSocket message.
pub const WEBSOCKET_MESSAGE_BUFFER: usize = 16 << 10;
|
|
13
|
+
|
|
14
|
+
/// websocket_byte_stream adapts a binary WebSocket stream to Tokio byte I/O.
|
|
15
|
+
///
|
|
16
|
+
/// The adapter treats incoming binary WebSocket messages as consecutive bytes
|
|
17
|
+
/// and writes outbound byte chunks as binary WebSocket messages. Control and
|
|
18
|
+
/// text messages are ignored except that close terminates the byte stream.
|
|
19
|
+
pub fn websocket_byte_stream<S>(socket: WebSocketStream<S>) -> DuplexStream
|
|
20
|
+
where
|
|
21
|
+
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
22
|
+
{
|
|
23
|
+
let (mut socket_writer, mut socket_reader) = socket.split();
|
|
24
|
+
let (stream, peer) = tokio::io::duplex(WEBSOCKET_BYTE_STREAM_BUFFER);
|
|
25
|
+
let (mut peer_reader, mut peer_writer) = tokio::io::split(peer);
|
|
26
|
+
|
|
27
|
+
tokio::spawn(async move {
|
|
28
|
+
while let Some(result) = socket_reader.next().await {
|
|
29
|
+
match result {
|
|
30
|
+
Ok(message) if message.is_binary() => {
|
|
31
|
+
let data = message.into_data();
|
|
32
|
+
if peer_writer.write_all(data.as_ref()).await.is_err() {
|
|
33
|
+
break;
|
|
34
|
+
}
|
|
35
|
+
if peer_writer.flush().await.is_err() {
|
|
36
|
+
break;
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
Ok(message) if message.is_close() => break,
|
|
40
|
+
Ok(_) => {}
|
|
41
|
+
Err(_) => break,
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
let _ = peer_writer.shutdown().await;
|
|
45
|
+
});
|
|
46
|
+
|
|
47
|
+
tokio::spawn(async move {
|
|
48
|
+
let mut buf = vec![0; WEBSOCKET_MESSAGE_BUFFER];
|
|
49
|
+
loop {
|
|
50
|
+
match peer_reader.read(&mut buf).await {
|
|
51
|
+
Ok(0) => {
|
|
52
|
+
let _ = socket_writer.close().await;
|
|
53
|
+
break;
|
|
54
|
+
}
|
|
55
|
+
Ok(n) => {
|
|
56
|
+
let message = Message::Binary(Bytes::copy_from_slice(&buf[..n]));
|
|
57
|
+
if socket_writer.send(message).await.is_err() {
|
|
58
|
+
break;
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
Err(_) => {
|
|
62
|
+
let _ = socket_writer.close().await;
|
|
63
|
+
break;
|
|
64
|
+
}
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
});
|
|
68
|
+
|
|
69
|
+
stream
|
|
70
|
+
}
|
package/srpc/yamux.rs
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
//! Yamux transport adapters for starpc.
|
|
2
|
+
|
|
3
|
+
use async_trait::async_trait;
|
|
4
|
+
use futures::future::poll_fn;
|
|
5
|
+
use std::io;
|
|
6
|
+
use std::pin::Pin;
|
|
7
|
+
use std::sync::Arc;
|
|
8
|
+
use tokio::io::{AsyncRead, AsyncWrite};
|
|
9
|
+
use tokio::sync::{mpsc, oneshot};
|
|
10
|
+
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
|
11
|
+
|
|
12
|
+
use crate::client::{OpenStream, PacketReceiver};
|
|
13
|
+
use crate::error::{Error, Result};
|
|
14
|
+
use crate::invoker::Invoker;
|
|
15
|
+
use crate::rpc::PacketWriter;
|
|
16
|
+
use crate::server::Server;
|
|
17
|
+
use crate::transport::{create_packet_channel, DEFAULT_CHANNEL_BUFFER};
|
|
18
|
+
|
|
19
|
+
type OpenResult = std::result::Result<::yamux::Stream, ::yamux::ConnectionError>;
|
|
20
|
+
type OpenRequest = oneshot::Sender<OpenResult>;
|
|
21
|
+
|
|
22
|
+
/// YamuxStreamOpener opens Starpc packet streams over one yamux connection.
|
|
23
|
+
#[derive(Clone)]
|
|
24
|
+
pub struct YamuxStreamOpener {
|
|
25
|
+
requests: mpsc::Sender<OpenRequest>,
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
impl YamuxStreamOpener {
|
|
29
|
+
/// client creates a client-mode yamux opener over a Tokio transport.
|
|
30
|
+
pub fn client<T>(transport: T) -> Self
|
|
31
|
+
where
|
|
32
|
+
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
33
|
+
{
|
|
34
|
+
Self::client_with_config(transport, ::yamux::Config::default())
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
/// client_with_config creates a client-mode yamux opener with a custom config.
|
|
38
|
+
pub fn client_with_config<T>(transport: T, config: ::yamux::Config) -> Self
|
|
39
|
+
where
|
|
40
|
+
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
41
|
+
{
|
|
42
|
+
let (requests, request_rx) = mpsc::channel(DEFAULT_CHANNEL_BUFFER);
|
|
43
|
+
spawn_client_driver(transport.compat(), config, request_rx);
|
|
44
|
+
Self { requests }
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
/// client_websocket creates a client-mode yamux opener over a WebSocket.
|
|
48
|
+
#[cfg(feature = "websocket")]
|
|
49
|
+
pub fn client_websocket<S>(socket: tokio_tungstenite::WebSocketStream<S>) -> Self
|
|
50
|
+
where
|
|
51
|
+
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
52
|
+
{
|
|
53
|
+
Self::client(crate::websocket::websocket_byte_stream(socket))
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/// client_websocket_with_config creates a client-mode yamux WebSocket opener with a custom config.
|
|
57
|
+
#[cfg(feature = "websocket")]
|
|
58
|
+
pub fn client_websocket_with_config<S>(
|
|
59
|
+
socket: tokio_tungstenite::WebSocketStream<S>,
|
|
60
|
+
config: ::yamux::Config,
|
|
61
|
+
) -> Self
|
|
62
|
+
where
|
|
63
|
+
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
64
|
+
{
|
|
65
|
+
Self::client_with_config(crate::websocket::websocket_byte_stream(socket), config)
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
#[async_trait]
|
|
70
|
+
impl OpenStream for YamuxStreamOpener {
|
|
71
|
+
async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
|
|
72
|
+
let (tx, rx) = oneshot::channel();
|
|
73
|
+
self.requests
|
|
74
|
+
.send(tx)
|
|
75
|
+
.await
|
|
76
|
+
.map_err(|_| Error::StreamClosed)?;
|
|
77
|
+
|
|
78
|
+
let stream = rx
|
|
79
|
+
.await
|
|
80
|
+
.map_err(|_| Error::StreamClosed)?
|
|
81
|
+
.map_err(connection_error)?;
|
|
82
|
+
let stream = stream.compat();
|
|
83
|
+
let (read_half, write_half) = tokio::io::split(stream);
|
|
84
|
+
Ok(create_packet_channel(read_half, write_half))
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
/// handle_server_connection accepts yamux streams and serves each as Starpc.
|
|
89
|
+
pub async fn handle_server_connection<I, T>(
|
|
90
|
+
server: &Server<I>,
|
|
91
|
+
transport: T,
|
|
92
|
+
config: ::yamux::Config,
|
|
93
|
+
) -> Result<()>
|
|
94
|
+
where
|
|
95
|
+
I: Invoker + 'static,
|
|
96
|
+
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
97
|
+
{
|
|
98
|
+
let mut connection =
|
|
99
|
+
::yamux::Connection::new(transport.compat(), config, ::yamux::Mode::Server);
|
|
100
|
+
|
|
101
|
+
loop {
|
|
102
|
+
match poll_fn(|cx| connection.poll_next_inbound(cx)).await {
|
|
103
|
+
Some(Ok(stream)) => {
|
|
104
|
+
let server = server.clone_for_spawn();
|
|
105
|
+
tokio::spawn(async move {
|
|
106
|
+
if let Err(err) = server.handle_stream(stream.compat()).await {
|
|
107
|
+
server.report_error(err);
|
|
108
|
+
}
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
Some(Err(err)) => return Err(connection_error(err)),
|
|
112
|
+
None => return Ok(()),
|
|
113
|
+
}
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
enum DriverEvent {
|
|
118
|
+
Opened(OpenRequest, OpenResult),
|
|
119
|
+
Closed {
|
|
120
|
+
pending: Option<OpenRequest>,
|
|
121
|
+
err: Option<::yamux::ConnectionError>,
|
|
122
|
+
},
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
fn spawn_client_driver<T>(
|
|
126
|
+
transport: T,
|
|
127
|
+
config: ::yamux::Config,
|
|
128
|
+
mut requests: mpsc::Receiver<OpenRequest>,
|
|
129
|
+
) -> tokio::task::JoinHandle<()>
|
|
130
|
+
where
|
|
131
|
+
T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
|
|
132
|
+
{
|
|
133
|
+
tokio::spawn(async move {
|
|
134
|
+
let mut connection = ::yamux::Connection::new(transport, config, ::yamux::Mode::Client);
|
|
135
|
+
let mut pending = None;
|
|
136
|
+
let mut requests_closed = false;
|
|
137
|
+
|
|
138
|
+
loop {
|
|
139
|
+
let event = poll_fn(|cx| {
|
|
140
|
+
if pending.is_none() && !requests_closed {
|
|
141
|
+
match Pin::new(&mut requests).poll_recv(cx) {
|
|
142
|
+
std::task::Poll::Ready(Some(request)) => pending = Some(request),
|
|
143
|
+
std::task::Poll::Ready(None) => requests_closed = true,
|
|
144
|
+
std::task::Poll::Pending => {}
|
|
145
|
+
}
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
if let Some(request) = pending.take() {
|
|
149
|
+
match connection.poll_new_outbound(cx) {
|
|
150
|
+
std::task::Poll::Ready(result) => {
|
|
151
|
+
return std::task::Poll::Ready(DriverEvent::Opened(request, result));
|
|
152
|
+
}
|
|
153
|
+
std::task::Poll::Pending => pending = Some(request),
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
loop {
|
|
158
|
+
match connection.poll_next_inbound(cx) {
|
|
159
|
+
std::task::Poll::Ready(Some(Ok(_stream))) => continue,
|
|
160
|
+
std::task::Poll::Ready(Some(Err(err))) => {
|
|
161
|
+
return std::task::Poll::Ready(DriverEvent::Closed {
|
|
162
|
+
pending: pending.take(),
|
|
163
|
+
err: Some(err),
|
|
164
|
+
});
|
|
165
|
+
}
|
|
166
|
+
std::task::Poll::Ready(None) => {
|
|
167
|
+
return std::task::Poll::Ready(DriverEvent::Closed {
|
|
168
|
+
pending: pending.take(),
|
|
169
|
+
err: None,
|
|
170
|
+
});
|
|
171
|
+
}
|
|
172
|
+
std::task::Poll::Pending => return std::task::Poll::Pending,
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
})
|
|
176
|
+
.await;
|
|
177
|
+
|
|
178
|
+
match event {
|
|
179
|
+
DriverEvent::Opened(request, result) => {
|
|
180
|
+
let _ = request.send(result);
|
|
181
|
+
}
|
|
182
|
+
DriverEvent::Closed { pending, err } => {
|
|
183
|
+
if let Some(request) = pending {
|
|
184
|
+
let _ = request.send(Err(err.unwrap_or(::yamux::ConnectionError::Closed)));
|
|
185
|
+
}
|
|
186
|
+
break;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
})
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
fn connection_error(err: ::yamux::ConnectionError) -> Error {
|
|
194
|
+
match err {
|
|
195
|
+
::yamux::ConnectionError::Io(err) => Error::Io(err),
|
|
196
|
+
other => Error::Io(io::Error::new(io::ErrorKind::Other, other.to_string())),
|
|
197
|
+
}
|
|
198
|
+
}
|