starpc 0.41.2 → 0.43.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +101 -20
- package/dist/mock/mock.pb.d.ts +1 -1
- package/dist/mock/mock.pb.js +4 -5
- package/dist/mock/mock_srpc.pb.d.ts +3 -3
- package/dist/mock/mock_srpc.pb.js +5 -5
- package/echo/Cargo.toml +21 -0
- package/echo/build.rs +15 -0
- package/echo/echo.pb.cc +405 -0
- package/echo/echo.pb.go +9 -27
- package/echo/echo.pb.h +364 -0
- package/echo/echo_srpc.pb.go +1 -1
- package/echo/gen/mod.rs +3 -0
- package/echo/main.rs +162 -0
- package/go.mod +18 -10
- package/go.sum +28 -18
- package/mock/mock.pb.cc +394 -0
- package/mock/mock.pb.go +9 -27
- package/mock/mock.pb.h +366 -0
- package/mock/mock.pb.ts +11 -13
- package/mock/mock_srpc.pb.go +1 -1
- package/mock/mock_srpc.pb.ts +12 -9
- package/package.json +27 -25
- package/srpc/Cargo.toml +26 -0
- package/srpc/build.rs +15 -0
- package/srpc/client.rs +356 -0
- package/srpc/codec.rs +225 -0
- package/srpc/error.rs +177 -0
- package/srpc/handler.rs +163 -0
- package/srpc/invoker.rs +192 -0
- package/srpc/lib.rs +107 -0
- package/srpc/message.rs +9 -0
- package/srpc/mux.rs +353 -0
- package/srpc/packet.rs +334 -0
- package/srpc/proto/mod.rs +10 -0
- package/srpc/rpc.rs +777 -0
- package/srpc/rpcproto.pb.cc +1381 -0
- package/srpc/rpcproto.pb.go +75 -183
- package/srpc/rpcproto.pb.h +1451 -0
- package/srpc/server.rs +337 -0
- package/srpc/stream.rs +304 -0
- package/srpc/testing.rs +290 -0
- package/srpc/tests/integration_test.rs +495 -0
- package/srpc/transport.rs +218 -0
- package/Makefile +0 -154
package/srpc/rpc.rs
ADDED
|
@@ -0,0 +1,777 @@
|
|
|
1
|
+
//! RPC state machines for client and server.
|
|
2
|
+
//!
|
|
3
|
+
//! This module provides the core RPC state machines that manage the lifecycle
|
|
4
|
+
//! of RPC calls, matching the behavior of the Go and TypeScript implementations.
|
|
5
|
+
|
|
6
|
+
use async_trait::async_trait;
|
|
7
|
+
use bytes::Bytes;
|
|
8
|
+
use std::collections::VecDeque;
|
|
9
|
+
use std::sync::atomic::{AtomicBool, Ordering};
|
|
10
|
+
use std::sync::Arc;
|
|
11
|
+
use tokio::sync::{Mutex, Notify};
|
|
12
|
+
|
|
13
|
+
use crate::error::{Error, Result};
|
|
14
|
+
use crate::packet::{new_call_cancel, new_call_data_full, new_call_start, Validate};
|
|
15
|
+
use crate::proto::{packet::Body, CallData, CallStart, Packet};
|
|
16
|
+
use crate::stream::{Context, Stream};
|
|
17
|
+
use crate::transport::decode_optional_data;
|
|
18
|
+
|
|
19
|
+
/// Trait for writing packets to the transport.
|
|
20
|
+
#[async_trait]
|
|
21
|
+
pub trait PacketWriter: Send + Sync {
|
|
22
|
+
/// Writes a packet to the transport.
|
|
23
|
+
async fn write_packet(&self, packet: Packet) -> Result<()>;
|
|
24
|
+
|
|
25
|
+
/// Closes the writer.
|
|
26
|
+
async fn close(&self) -> Result<()>;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
/// Common RPC state shared between client and server.
|
|
30
|
+
///
|
|
31
|
+
/// This struct manages the state machine for an RPC call, including:
|
|
32
|
+
/// - Message queuing and delivery
|
|
33
|
+
/// - Completion tracking
|
|
34
|
+
/// - Error handling
|
|
35
|
+
/// - Cancellation
|
|
36
|
+
pub struct CommonRpc {
|
|
37
|
+
/// Context for this RPC.
|
|
38
|
+
ctx: Context,
|
|
39
|
+
|
|
40
|
+
/// Service identifier.
|
|
41
|
+
service: String,
|
|
42
|
+
|
|
43
|
+
/// Method identifier.
|
|
44
|
+
method: String,
|
|
45
|
+
|
|
46
|
+
/// Whether we have completed locally (sent complete/cancel).
|
|
47
|
+
/// Note: not guarded by the mutex, uses atomic operations.
|
|
48
|
+
local_completed: AtomicBool,
|
|
49
|
+
|
|
50
|
+
/// Packet writer.
|
|
51
|
+
writer: Arc<dyn PacketWriter>,
|
|
52
|
+
|
|
53
|
+
/// Notification for state changes.
|
|
54
|
+
notify: Notify,
|
|
55
|
+
|
|
56
|
+
/// Internal state protected by mutex.
|
|
57
|
+
state: Mutex<RpcState>,
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
/// Internal RPC state.
|
|
61
|
+
struct RpcState {
|
|
62
|
+
/// Queue of incoming data messages.
|
|
63
|
+
/// Note: messages may be len() == 0 (empty data with data_is_zero).
|
|
64
|
+
data_queue: VecDeque<Bytes>,
|
|
65
|
+
|
|
66
|
+
/// Whether the remote side has closed the stream.
|
|
67
|
+
data_closed: bool,
|
|
68
|
+
|
|
69
|
+
/// Error from the remote side, if any.
|
|
70
|
+
remote_err: Option<String>,
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
impl CommonRpc {
|
|
74
|
+
/// Creates a new CommonRpc.
|
|
75
|
+
pub fn new(
|
|
76
|
+
ctx: Context,
|
|
77
|
+
service: String,
|
|
78
|
+
method: String,
|
|
79
|
+
writer: Arc<dyn PacketWriter>,
|
|
80
|
+
) -> Self {
|
|
81
|
+
Self {
|
|
82
|
+
ctx,
|
|
83
|
+
service,
|
|
84
|
+
method,
|
|
85
|
+
local_completed: AtomicBool::new(false),
|
|
86
|
+
writer,
|
|
87
|
+
notify: Notify::new(),
|
|
88
|
+
state: Mutex::new(RpcState {
|
|
89
|
+
data_queue: VecDeque::new(),
|
|
90
|
+
data_closed: false,
|
|
91
|
+
remote_err: None,
|
|
92
|
+
}),
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
/// Returns the context for this RPC.
|
|
97
|
+
pub fn context(&self) -> &Context {
|
|
98
|
+
&self.ctx
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/// Returns the service ID.
|
|
102
|
+
pub fn service(&self) -> &str {
|
|
103
|
+
&self.service
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
/// Returns the method ID.
|
|
107
|
+
pub fn method(&self) -> &str {
|
|
108
|
+
&self.method
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
/// Returns true if the RPC has completed locally.
|
|
112
|
+
pub fn is_local_completed(&self) -> bool {
|
|
113
|
+
self.local_completed.load(Ordering::SeqCst)
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
/// Waits for the RPC to finish (remote end closed the stream).
|
|
117
|
+
///
|
|
118
|
+
/// This matches the Go implementation's `Wait(ctx context.Context) error`.
|
|
119
|
+
pub async fn wait(&self) -> Result<()> {
|
|
120
|
+
loop {
|
|
121
|
+
// Check current state
|
|
122
|
+
{
|
|
123
|
+
let state = self.state.lock().await;
|
|
124
|
+
|
|
125
|
+
// Check cancellation first - local cancellation takes priority
|
|
126
|
+
if self.ctx.is_cancelled() {
|
|
127
|
+
return Err(Error::Cancelled);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
if let Some(ref err) = state.remote_err {
|
|
131
|
+
return Err(Error::Remote(err.clone()));
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
if state.data_closed {
|
|
135
|
+
return Ok(());
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
// Wait for notification or cancellation
|
|
140
|
+
tokio::select! {
|
|
141
|
+
_ = self.notify.notified() => continue,
|
|
142
|
+
_ = self.ctx.cancelled() => return Err(Error::Cancelled),
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
/// Reads one message from the data queue, blocking until available.
|
|
148
|
+
///
|
|
149
|
+
/// Returns `Err(Error::StreamClosed)` if the stream ended without a packet.
|
|
150
|
+
/// This matches the Go implementation which returns `io.EOF`.
|
|
151
|
+
pub async fn read_one(&self) -> Result<Bytes> {
|
|
152
|
+
loop {
|
|
153
|
+
// Try to get a message from the queue first, before checking cancellation.
|
|
154
|
+
// This ensures we drain any pending messages even if the context is cancelled.
|
|
155
|
+
{
|
|
156
|
+
let mut state = self.state.lock().await;
|
|
157
|
+
|
|
158
|
+
if let Some(data) = state.data_queue.pop_front() {
|
|
159
|
+
return Ok(data);
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
// Check if the stream is closed (graceful close from remote)
|
|
163
|
+
if state.data_closed {
|
|
164
|
+
if let Some(ref err) = state.remote_err {
|
|
165
|
+
return Err(Error::Remote(err.clone()));
|
|
166
|
+
}
|
|
167
|
+
return Err(Error::StreamClosed);
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
// Now check for cancellation - only if no data is available
|
|
172
|
+
// and the stream isn't properly closed.
|
|
173
|
+
if self.ctx.is_cancelled() {
|
|
174
|
+
// If context cancelled and data not closed, close it now
|
|
175
|
+
let mut state = self.state.lock().await;
|
|
176
|
+
if !state.data_closed {
|
|
177
|
+
state.data_closed = true;
|
|
178
|
+
drop(state);
|
|
179
|
+
let _ = self.writer.close().await;
|
|
180
|
+
self.ctx.cancel();
|
|
181
|
+
self.notify.notify_waiters();
|
|
182
|
+
}
|
|
183
|
+
return Err(Error::Cancelled);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// Wait for notification or cancellation
|
|
187
|
+
tokio::select! {
|
|
188
|
+
_ = self.notify.notified() => continue,
|
|
189
|
+
_ = self.ctx.cancelled() => {
|
|
190
|
+
// Loop will handle the cancellation
|
|
191
|
+
continue;
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
/// Writes a CallData packet.
|
|
198
|
+
///
|
|
199
|
+
/// # Arguments
|
|
200
|
+
/// * `data` - Optional data to send
|
|
201
|
+
/// * `complete` - Whether this completes the RPC
|
|
202
|
+
/// * `error` - Optional error message
|
|
203
|
+
pub async fn write_call_data(
|
|
204
|
+
&self,
|
|
205
|
+
data: Option<Bytes>,
|
|
206
|
+
complete: bool,
|
|
207
|
+
error: Option<String>,
|
|
208
|
+
) -> Result<()> {
|
|
209
|
+
let should_complete = complete || error.is_some();
|
|
210
|
+
|
|
211
|
+
// Check if already completed
|
|
212
|
+
if should_complete {
|
|
213
|
+
if self
|
|
214
|
+
.local_completed
|
|
215
|
+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
|
216
|
+
.is_err()
|
|
217
|
+
{
|
|
218
|
+
// If we're just marking completion and already completed, allow it (no-op)
|
|
219
|
+
// This matches Go behavior
|
|
220
|
+
if complete && data.is_none() && error.is_none() {
|
|
221
|
+
return Ok(());
|
|
222
|
+
}
|
|
223
|
+
return Err(Error::Completed);
|
|
224
|
+
}
|
|
225
|
+
} else if self.local_completed.load(Ordering::SeqCst) {
|
|
226
|
+
return Err(Error::Completed);
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
let packet = new_call_data_full(data, complete, error);
|
|
230
|
+
self.writer.write_packet(packet).await
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
/// Writes a CallCancel packet.
|
|
234
|
+
///
|
|
235
|
+
/// This atomically checks and sets the completed flag, then sends the cancel.
|
|
236
|
+
pub async fn write_call_cancel(&self) -> Result<()> {
|
|
237
|
+
// Use atomic swap to check and set completion atomically
|
|
238
|
+
if self.local_completed.swap(true, Ordering::SeqCst) {
|
|
239
|
+
return Err(Error::Completed);
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
self.writer.write_packet(new_call_cancel()).await
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
/// Handles an incoming CallData packet.
|
|
246
|
+
pub async fn handle_call_data(&self, call_data: CallData) -> Result<()> {
|
|
247
|
+
let mut state = self.state.lock().await;
|
|
248
|
+
|
|
249
|
+
// Check if already closed
|
|
250
|
+
if state.data_closed {
|
|
251
|
+
// If the packet is just indicating the call is complete, ignore it
|
|
252
|
+
// This matches Go behavior
|
|
253
|
+
if call_data.complete {
|
|
254
|
+
return Ok(());
|
|
255
|
+
}
|
|
256
|
+
return Err(Error::Completed);
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
// Extract data if present
|
|
260
|
+
if let Some(data) = decode_optional_data(call_data.data, call_data.data_is_zero) {
|
|
261
|
+
state.data_queue.push_back(data);
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// Handle completion or error
|
|
265
|
+
if !call_data.error.is_empty() {
|
|
266
|
+
state.remote_err = Some(call_data.error);
|
|
267
|
+
state.data_closed = true;
|
|
268
|
+
} else if call_data.complete {
|
|
269
|
+
state.data_closed = true;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
// Notify waiters
|
|
273
|
+
drop(state);
|
|
274
|
+
self.notify.notify_waiters();
|
|
275
|
+
|
|
276
|
+
Ok(())
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
/// Handles a CallCancel packet.
|
|
280
|
+
pub async fn handle_call_cancel(&self) -> Result<()> {
|
|
281
|
+
self.handle_stream_close(Some("cancelled".to_string())).await
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
/// Handles stream close from the transport.
|
|
285
|
+
pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
|
|
286
|
+
let mut state = self.state.lock().await;
|
|
287
|
+
|
|
288
|
+
if let Some(e) = err {
|
|
289
|
+
if state.remote_err.is_none() {
|
|
290
|
+
state.remote_err = Some(e);
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
state.data_closed = true;
|
|
294
|
+
|
|
295
|
+
drop(state);
|
|
296
|
+
|
|
297
|
+
let _ = self.writer.close().await;
|
|
298
|
+
self.ctx.cancel();
|
|
299
|
+
self.notify.notify_waiters();
|
|
300
|
+
|
|
301
|
+
Ok(())
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
/// Closes the RPC, releasing resources.
|
|
305
|
+
///
|
|
306
|
+
/// This is called internally and handles cleanup.
|
|
307
|
+
/// Note: We don't set remote_err here because local cancellation is signaled
|
|
308
|
+
/// via ctx.is_cancelled(), which is checked first in wait()/read_one().
|
|
309
|
+
async fn close_locked(&self) {
|
|
310
|
+
let mut state = self.state.lock().await;
|
|
311
|
+
state.data_closed = true;
|
|
312
|
+
self.local_completed.store(true, Ordering::SeqCst);
|
|
313
|
+
// Don't set remote_err for local closure - rely on ctx.is_cancelled() instead
|
|
314
|
+
// to distinguish between local cancellation and remote errors.
|
|
315
|
+
drop(state);
|
|
316
|
+
|
|
317
|
+
let _ = self.writer.close().await;
|
|
318
|
+
self.notify.notify_waiters();
|
|
319
|
+
self.ctx.cancel();
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
/// Client-side RPC state machine.
|
|
324
|
+
pub struct ClientRpc {
|
|
325
|
+
common: CommonRpc,
|
|
326
|
+
/// Whether CallStart has been sent.
|
|
327
|
+
start_sent: AtomicBool,
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
impl ClientRpc {
|
|
331
|
+
/// Creates a new ClientRpc.
|
|
332
|
+
pub fn new(
|
|
333
|
+
ctx: Context,
|
|
334
|
+
service: String,
|
|
335
|
+
method: String,
|
|
336
|
+
writer: Arc<dyn PacketWriter>,
|
|
337
|
+
) -> Self {
|
|
338
|
+
Self {
|
|
339
|
+
common: CommonRpc::new(ctx, service, method, writer),
|
|
340
|
+
start_sent: AtomicBool::new(false),
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
/// Returns the context for this RPC.
|
|
345
|
+
pub fn context(&self) -> &Context {
|
|
346
|
+
self.common.context()
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
/// Returns the service ID.
|
|
350
|
+
pub fn service(&self) -> &str {
|
|
351
|
+
self.common.service()
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
/// Returns the method ID.
|
|
355
|
+
pub fn method(&self) -> &str {
|
|
356
|
+
self.common.method()
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
/// Waits for the RPC to finish.
|
|
360
|
+
pub async fn wait(&self) -> Result<()> {
|
|
361
|
+
self.common.wait().await
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
/// Starts the RPC call with optional initial data.
|
|
365
|
+
pub async fn start(&self, data: Option<Bytes>) -> Result<()> {
|
|
366
|
+
if self
|
|
367
|
+
.start_sent
|
|
368
|
+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
|
369
|
+
.is_err()
|
|
370
|
+
{
|
|
371
|
+
return Err(Error::Completed);
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
// Check context before starting
|
|
375
|
+
if self.common.ctx.is_cancelled() {
|
|
376
|
+
self.common.ctx.cancel();
|
|
377
|
+
let _ = self.common.writer.close().await;
|
|
378
|
+
return Err(Error::Cancelled);
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
let packet = new_call_start(
|
|
382
|
+
self.common.service.clone(),
|
|
383
|
+
self.common.method.clone(),
|
|
384
|
+
data,
|
|
385
|
+
);
|
|
386
|
+
|
|
387
|
+
if let Err(e) = self.common.writer.write_packet(packet).await {
|
|
388
|
+
self.common.ctx.cancel();
|
|
389
|
+
let _ = self.common.writer.close().await;
|
|
390
|
+
return Err(e);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
Ok(())
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
/// Handles an incoming packet.
|
|
397
|
+
pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
|
|
398
|
+
// Validate the packet first
|
|
399
|
+
packet.validate()?;
|
|
400
|
+
|
|
401
|
+
match packet.body {
|
|
402
|
+
Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
|
|
403
|
+
Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
|
|
404
|
+
Some(Body::CallCancel(false)) => Ok(()),
|
|
405
|
+
Some(Body::CallStart(_)) => {
|
|
406
|
+
// Server-to-client calls not supported
|
|
407
|
+
Err(Error::UnrecognizedPacket)
|
|
408
|
+
}
|
|
409
|
+
None => Err(Error::EmptyPacket),
|
|
410
|
+
}
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
/// Handles stream close from the transport.
|
|
414
|
+
pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
|
|
415
|
+
self.common.handle_stream_close(err).await
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
/// Closes the client RPC.
|
|
419
|
+
///
|
|
420
|
+
/// This sends a cancel packet (if not already completed) and releases resources.
|
|
421
|
+
/// Matches the Go implementation's `Close()` behavior.
|
|
422
|
+
pub async fn close(&self) {
|
|
423
|
+
// Only proceed if writer was set (start was called)
|
|
424
|
+
if !self.start_sent.load(Ordering::SeqCst) {
|
|
425
|
+
return;
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
// Try to send cancel, ignore errors
|
|
429
|
+
let _ = self.common.write_call_cancel().await;
|
|
430
|
+
|
|
431
|
+
// Close resources
|
|
432
|
+
self.common.close_locked().await;
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
#[async_trait]
|
|
437
|
+
impl Stream for ClientRpc {
|
|
438
|
+
fn context(&self) -> &Context {
|
|
439
|
+
&self.common.ctx
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()> {
|
|
443
|
+
self.common
|
|
444
|
+
.write_call_data(Some(data), false, None)
|
|
445
|
+
.await
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
async fn recv_bytes(&self) -> Result<Bytes> {
|
|
449
|
+
self.common.read_one().await
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
async fn close_send(&self) -> Result<()> {
|
|
453
|
+
self.common.write_call_data(None, true, None).await
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
async fn close(&self) -> Result<()> {
|
|
457
|
+
ClientRpc::close(self).await;
|
|
458
|
+
Ok(())
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
/// Server-side RPC state machine.
|
|
463
|
+
pub struct ServerRpc {
|
|
464
|
+
common: CommonRpc,
|
|
465
|
+
/// Initial data from CallStart, if any.
|
|
466
|
+
initial_data: Mutex<Option<Bytes>>,
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
impl ServerRpc {
|
|
470
|
+
/// Creates a new ServerRpc from a CallStart packet.
|
|
471
|
+
pub fn from_call_start(
|
|
472
|
+
ctx: Context,
|
|
473
|
+
call_start: CallStart,
|
|
474
|
+
writer: Arc<dyn PacketWriter>,
|
|
475
|
+
) -> Self {
|
|
476
|
+
let initial_data = decode_optional_data(call_start.data, call_start.data_is_zero);
|
|
477
|
+
|
|
478
|
+
Self {
|
|
479
|
+
common: CommonRpc::new(ctx, call_start.rpc_service, call_start.rpc_method, writer),
|
|
480
|
+
initial_data: Mutex::new(initial_data),
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
/// Returns the context for this RPC.
|
|
485
|
+
pub fn context(&self) -> &Context {
|
|
486
|
+
self.common.context()
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
/// Returns the service ID.
|
|
490
|
+
pub fn service(&self) -> &str {
|
|
491
|
+
self.common.service()
|
|
492
|
+
}
|
|
493
|
+
|
|
494
|
+
/// Returns the method ID.
|
|
495
|
+
pub fn method(&self) -> &str {
|
|
496
|
+
self.common.method()
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
/// Waits for the RPC to finish.
|
|
500
|
+
pub async fn wait(&self) -> Result<()> {
|
|
501
|
+
self.common.wait().await
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
/// Handles an incoming packet.
|
|
505
|
+
pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
|
|
506
|
+
// Validate the packet first
|
|
507
|
+
packet.validate()?;
|
|
508
|
+
|
|
509
|
+
match packet.body {
|
|
510
|
+
Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
|
|
511
|
+
Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
|
|
512
|
+
Some(Body::CallCancel(false)) => Ok(()),
|
|
513
|
+
Some(Body::CallStart(_)) => {
|
|
514
|
+
// CallStart should only be sent once
|
|
515
|
+
Err(Error::DuplicateCallStart)
|
|
516
|
+
}
|
|
517
|
+
None => Err(Error::EmptyPacket),
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
/// Handles stream close from the transport.
|
|
522
|
+
pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
|
|
523
|
+
self.common.handle_stream_close(err).await
|
|
524
|
+
}
|
|
525
|
+
|
|
526
|
+
/// Sends an error response and closes.
|
|
527
|
+
pub async fn send_error(&self, error: String) -> Result<()> {
|
|
528
|
+
self.common.write_call_data(None, true, Some(error)).await
|
|
529
|
+
}
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
#[async_trait]
|
|
533
|
+
impl Stream for ServerRpc {
|
|
534
|
+
fn context(&self) -> &Context {
|
|
535
|
+
&self.common.ctx
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
async fn send_bytes(&self, data: Bytes) -> Result<()> {
|
|
539
|
+
self.common
|
|
540
|
+
.write_call_data(Some(data), false, None)
|
|
541
|
+
.await
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
async fn recv_bytes(&self) -> Result<Bytes> {
|
|
545
|
+
// First check for initial data
|
|
546
|
+
{
|
|
547
|
+
let mut initial = self.initial_data.lock().await;
|
|
548
|
+
if let Some(data) = initial.take() {
|
|
549
|
+
return Ok(data);
|
|
550
|
+
}
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
// Then read from the queue
|
|
554
|
+
self.common.read_one().await
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
async fn close_send(&self) -> Result<()> {
|
|
558
|
+
self.common.write_call_data(None, true, None).await
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
async fn close(&self) -> Result<()> {
|
|
562
|
+
// Mark as locally completed
|
|
563
|
+
self.common.local_completed.store(true, Ordering::SeqCst);
|
|
564
|
+
|
|
565
|
+
// Close the writer
|
|
566
|
+
self.common.writer.close().await?;
|
|
567
|
+
|
|
568
|
+
// Cancel the context
|
|
569
|
+
self.common.ctx.cancel();
|
|
570
|
+
|
|
571
|
+
Ok(())
|
|
572
|
+
}
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    /// In-memory `PacketWriter` that records every packet written and
    /// whether `close()` was called.
    struct MockWriter {
        packets: StdMutex<Vec<Packet>>,
        closed: AtomicBool,
    }

    impl MockWriter {
        fn new() -> Self {
            Self {
                packets: StdMutex::new(Vec::new()),
                closed: AtomicBool::new(false),
            }
        }

        /// Snapshot of the packets written so far.
        fn packets(&self) -> Vec<Packet> {
            let recorded = self.packets.lock().unwrap();
            recorded.clone()
        }

        fn is_closed(&self) -> bool {
            self.closed.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl PacketWriter for MockWriter {
        async fn write_packet(&self, packet: Packet) -> Result<()> {
            self.packets.lock().unwrap().push(packet);
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Client RPC for `test.Service` / `TestMethod` writing into `writer`.
    fn test_client(writer: Arc<MockWriter>) -> ClientRpc {
        ClientRpc::new(
            Context::new(),
            "test.Service".into(),
            "TestMethod".into(),
            writer,
        )
    }

    /// Bare `CommonRpc` for `svc` / `method` writing into `writer`.
    fn test_common(writer: Arc<MockWriter>) -> CommonRpc {
        CommonRpc::new(Context::new(), "svc".into(), "method".into(), writer)
    }

    /// Incoming CallData packet body with `data_is_zero` unset.
    fn incoming(data: Vec<u8>, complete: bool, error: &str) -> CallData {
        CallData {
            data,
            data_is_zero: false,
            complete,
            error: error.to_string(),
        }
    }

    #[tokio::test]
    async fn test_client_rpc_start() {
        let writer = Arc::new(MockWriter::new());
        let rpc = test_client(writer.clone());

        rpc.start(Some(Bytes::from(vec![1, 2, 3]))).await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);

        if let Some(Body::CallStart(cs)) = &sent[0].body {
            assert_eq!(cs.rpc_service, "test.Service");
            assert_eq!(cs.rpc_method, "TestMethod");
            assert_eq!(cs.data, vec![1, 2, 3]);
            assert!(!cs.data_is_zero);
        } else {
            panic!("expected CallStart");
        }
    }

    #[tokio::test]
    async fn test_client_rpc_double_start_fails() {
        let rpc = test_client(Arc::new(MockWriter::new()));

        rpc.start(None).await.unwrap();

        let second = rpc.start(None).await;
        assert!(matches!(second, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_client_rpc_close_sends_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = test_client(writer.clone());

        rpc.start(None).await.unwrap();
        rpc.close().await;

        // Expect CallStart followed by CallCancel, and a closed writer.
        let sent = writer.packets();
        assert_eq!(sent.len(), 2);
        assert!(sent[1].is_call_cancel());
        assert!(writer.is_closed());
    }

    #[tokio::test]
    async fn test_server_rpc_from_call_start() {
        let call_start = CallStart {
            rpc_service: "test.Service".into(),
            rpc_method: "TestMethod".into(),
            data: vec![1, 2, 3],
            data_is_zero: false,
        };

        let rpc = ServerRpc::from_call_start(
            Context::new(),
            call_start,
            Arc::new(MockWriter::new()),
        );

        assert_eq!(rpc.service(), "test.Service");
        assert_eq!(rpc.method(), "TestMethod");

        // The CallStart payload is delivered by the first receive.
        let first = rpc.recv_bytes().await.unwrap();
        assert_eq!(&first[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_data() {
        let rpc = test_common(Arc::new(MockWriter::new()));

        // Simulate a data packet arriving from the remote.
        rpc.handle_call_data(incoming(vec![1, 2, 3], false, ""))
            .await
            .unwrap();

        let received = rpc.read_one().await.unwrap();
        assert_eq!(&received[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_stream_closed() {
        let rpc = test_common(Arc::new(MockWriter::new()));

        // Simulate the remote completing without sending any data.
        rpc.handle_call_data(incoming(vec![], true, ""))
            .await
            .unwrap();

        let outcome = rpc.read_one().await;
        assert!(matches!(outcome, Err(Error::StreamClosed)));
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_error() {
        let rpc = test_common(Arc::new(MockWriter::new()));

        // Simulate the remote completing with an error.
        rpc.handle_call_data(incoming(vec![], true, "test error"))
            .await
            .unwrap();

        match rpc.read_one().await {
            Err(Error::Remote(msg)) => assert_eq!(msg, "test error"),
            _ => panic!("expected Remote error"),
        }
    }

    #[tokio::test]
    async fn test_write_call_data_after_complete() {
        let rpc = test_common(Arc::new(MockWriter::new()));

        // Complete the RPC.
        rpc.write_call_data(None, true, None).await.unwrap();

        // Further data must be rejected.
        let outcome = rpc
            .write_call_data(Some(Bytes::from(vec![1])), false, None)
            .await;
        assert!(matches!(outcome, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_write_call_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = test_common(writer.clone());

        rpc.write_call_cancel().await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].is_call_cancel());

        // A second cancel must be rejected.
        let outcome = rpc.write_call_cancel().await;
        assert!(matches!(outcome, Err(Error::Completed)));
    }
}
|