starpc 0.41.2 → 0.43.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/srpc/rpc.rs ADDED
@@ -0,0 +1,777 @@
1
+ //! RPC state machines for client and server.
2
+ //!
3
+ //! This module provides the core RPC state machines that manage the lifecycle
4
+ //! of RPC calls, matching the behavior of the Go and TypeScript implementations.
5
+
6
+ use async_trait::async_trait;
7
+ use bytes::Bytes;
8
+ use std::collections::VecDeque;
9
+ use std::sync::atomic::{AtomicBool, Ordering};
10
+ use std::sync::Arc;
11
+ use tokio::sync::{Mutex, Notify};
12
+
13
+ use crate::error::{Error, Result};
14
+ use crate::packet::{new_call_cancel, new_call_data_full, new_call_start, Validate};
15
+ use crate::proto::{packet::Body, CallData, CallStart, Packet};
16
+ use crate::stream::{Context, Stream};
17
+ use crate::transport::decode_optional_data;
18
+
19
+ /// Trait for writing packets to the transport.
20
+ #[async_trait]
21
+ pub trait PacketWriter: Send + Sync {
22
+ /// Writes a packet to the transport.
23
+ async fn write_packet(&self, packet: Packet) -> Result<()>;
24
+
25
+ /// Closes the writer.
26
+ async fn close(&self) -> Result<()>;
27
+ }
28
+
29
+ /// Common RPC state shared between client and server.
30
+ ///
31
+ /// This struct manages the state machine for an RPC call, including:
32
+ /// - Message queuing and delivery
33
+ /// - Completion tracking
34
+ /// - Error handling
35
+ /// - Cancellation
36
+ pub struct CommonRpc {
37
+ /// Context for this RPC.
38
+ ctx: Context,
39
+
40
+ /// Service identifier.
41
+ service: String,
42
+
43
+ /// Method identifier.
44
+ method: String,
45
+
46
+ /// Whether we have completed locally (sent complete/cancel).
47
+ /// Note: not guarded by the mutex, uses atomic operations.
48
+ local_completed: AtomicBool,
49
+
50
+ /// Packet writer.
51
+ writer: Arc<dyn PacketWriter>,
52
+
53
+ /// Notification for state changes.
54
+ notify: Notify,
55
+
56
+ /// Internal state protected by mutex.
57
+ state: Mutex<RpcState>,
58
+ }
59
+
60
+ /// Internal RPC state.
61
+ struct RpcState {
62
+ /// Queue of incoming data messages.
63
+ /// Note: messages may be len() == 0 (empty data with data_is_zero).
64
+ data_queue: VecDeque<Bytes>,
65
+
66
+ /// Whether the remote side has closed the stream.
67
+ data_closed: bool,
68
+
69
+ /// Error from the remote side, if any.
70
+ remote_err: Option<String>,
71
+ }
72
+
73
+ impl CommonRpc {
74
+ /// Creates a new CommonRpc.
75
+ pub fn new(
76
+ ctx: Context,
77
+ service: String,
78
+ method: String,
79
+ writer: Arc<dyn PacketWriter>,
80
+ ) -> Self {
81
+ Self {
82
+ ctx,
83
+ service,
84
+ method,
85
+ local_completed: AtomicBool::new(false),
86
+ writer,
87
+ notify: Notify::new(),
88
+ state: Mutex::new(RpcState {
89
+ data_queue: VecDeque::new(),
90
+ data_closed: false,
91
+ remote_err: None,
92
+ }),
93
+ }
94
+ }
95
+
96
+ /// Returns the context for this RPC.
97
+ pub fn context(&self) -> &Context {
98
+ &self.ctx
99
+ }
100
+
101
+ /// Returns the service ID.
102
+ pub fn service(&self) -> &str {
103
+ &self.service
104
+ }
105
+
106
+ /// Returns the method ID.
107
+ pub fn method(&self) -> &str {
108
+ &self.method
109
+ }
110
+
111
+ /// Returns true if the RPC has completed locally.
112
+ pub fn is_local_completed(&self) -> bool {
113
+ self.local_completed.load(Ordering::SeqCst)
114
+ }
115
+
116
+ /// Waits for the RPC to finish (remote end closed the stream).
117
+ ///
118
+ /// This matches the Go implementation's `Wait(ctx context.Context) error`.
119
+ pub async fn wait(&self) -> Result<()> {
120
+ loop {
121
+ // Check current state
122
+ {
123
+ let state = self.state.lock().await;
124
+
125
+ // Check cancellation first - local cancellation takes priority
126
+ if self.ctx.is_cancelled() {
127
+ return Err(Error::Cancelled);
128
+ }
129
+
130
+ if let Some(ref err) = state.remote_err {
131
+ return Err(Error::Remote(err.clone()));
132
+ }
133
+
134
+ if state.data_closed {
135
+ return Ok(());
136
+ }
137
+ }
138
+
139
+ // Wait for notification or cancellation
140
+ tokio::select! {
141
+ _ = self.notify.notified() => continue,
142
+ _ = self.ctx.cancelled() => return Err(Error::Cancelled),
143
+ }
144
+ }
145
+ }
146
+
147
+ /// Reads one message from the data queue, blocking until available.
148
+ ///
149
+ /// Returns `Err(Error::StreamClosed)` if the stream ended without a packet.
150
+ /// This matches the Go implementation which returns `io.EOF`.
151
+ pub async fn read_one(&self) -> Result<Bytes> {
152
+ loop {
153
+ // Try to get a message from the queue first, before checking cancellation.
154
+ // This ensures we drain any pending messages even if the context is cancelled.
155
+ {
156
+ let mut state = self.state.lock().await;
157
+
158
+ if let Some(data) = state.data_queue.pop_front() {
159
+ return Ok(data);
160
+ }
161
+
162
+ // Check if the stream is closed (graceful close from remote)
163
+ if state.data_closed {
164
+ if let Some(ref err) = state.remote_err {
165
+ return Err(Error::Remote(err.clone()));
166
+ }
167
+ return Err(Error::StreamClosed);
168
+ }
169
+ }
170
+
171
+ // Now check for cancellation - only if no data is available
172
+ // and the stream isn't properly closed.
173
+ if self.ctx.is_cancelled() {
174
+ // If context cancelled and data not closed, close it now
175
+ let mut state = self.state.lock().await;
176
+ if !state.data_closed {
177
+ state.data_closed = true;
178
+ drop(state);
179
+ let _ = self.writer.close().await;
180
+ self.ctx.cancel();
181
+ self.notify.notify_waiters();
182
+ }
183
+ return Err(Error::Cancelled);
184
+ }
185
+
186
+ // Wait for notification or cancellation
187
+ tokio::select! {
188
+ _ = self.notify.notified() => continue,
189
+ _ = self.ctx.cancelled() => {
190
+ // Loop will handle the cancellation
191
+ continue;
192
+ }
193
+ }
194
+ }
195
+ }
196
+
197
+ /// Writes a CallData packet.
198
+ ///
199
+ /// # Arguments
200
+ /// * `data` - Optional data to send
201
+ /// * `complete` - Whether this completes the RPC
202
+ /// * `error` - Optional error message
203
+ pub async fn write_call_data(
204
+ &self,
205
+ data: Option<Bytes>,
206
+ complete: bool,
207
+ error: Option<String>,
208
+ ) -> Result<()> {
209
+ let should_complete = complete || error.is_some();
210
+
211
+ // Check if already completed
212
+ if should_complete {
213
+ if self
214
+ .local_completed
215
+ .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
216
+ .is_err()
217
+ {
218
+ // If we're just marking completion and already completed, allow it (no-op)
219
+ // This matches Go behavior
220
+ if complete && data.is_none() && error.is_none() {
221
+ return Ok(());
222
+ }
223
+ return Err(Error::Completed);
224
+ }
225
+ } else if self.local_completed.load(Ordering::SeqCst) {
226
+ return Err(Error::Completed);
227
+ }
228
+
229
+ let packet = new_call_data_full(data, complete, error);
230
+ self.writer.write_packet(packet).await
231
+ }
232
+
233
+ /// Writes a CallCancel packet.
234
+ ///
235
+ /// This atomically checks and sets the completed flag, then sends the cancel.
236
+ pub async fn write_call_cancel(&self) -> Result<()> {
237
+ // Use atomic swap to check and set completion atomically
238
+ if self.local_completed.swap(true, Ordering::SeqCst) {
239
+ return Err(Error::Completed);
240
+ }
241
+
242
+ self.writer.write_packet(new_call_cancel()).await
243
+ }
244
+
245
+ /// Handles an incoming CallData packet.
246
+ pub async fn handle_call_data(&self, call_data: CallData) -> Result<()> {
247
+ let mut state = self.state.lock().await;
248
+
249
+ // Check if already closed
250
+ if state.data_closed {
251
+ // If the packet is just indicating the call is complete, ignore it
252
+ // This matches Go behavior
253
+ if call_data.complete {
254
+ return Ok(());
255
+ }
256
+ return Err(Error::Completed);
257
+ }
258
+
259
+ // Extract data if present
260
+ if let Some(data) = decode_optional_data(call_data.data, call_data.data_is_zero) {
261
+ state.data_queue.push_back(data);
262
+ }
263
+
264
+ // Handle completion or error
265
+ if !call_data.error.is_empty() {
266
+ state.remote_err = Some(call_data.error);
267
+ state.data_closed = true;
268
+ } else if call_data.complete {
269
+ state.data_closed = true;
270
+ }
271
+
272
+ // Notify waiters
273
+ drop(state);
274
+ self.notify.notify_waiters();
275
+
276
+ Ok(())
277
+ }
278
+
279
+ /// Handles a CallCancel packet.
280
+ pub async fn handle_call_cancel(&self) -> Result<()> {
281
+ self.handle_stream_close(Some("cancelled".to_string())).await
282
+ }
283
+
284
+ /// Handles stream close from the transport.
285
+ pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
286
+ let mut state = self.state.lock().await;
287
+
288
+ if let Some(e) = err {
289
+ if state.remote_err.is_none() {
290
+ state.remote_err = Some(e);
291
+ }
292
+ }
293
+ state.data_closed = true;
294
+
295
+ drop(state);
296
+
297
+ let _ = self.writer.close().await;
298
+ self.ctx.cancel();
299
+ self.notify.notify_waiters();
300
+
301
+ Ok(())
302
+ }
303
+
304
+ /// Closes the RPC, releasing resources.
305
+ ///
306
+ /// This is called internally and handles cleanup.
307
+ /// Note: We don't set remote_err here because local cancellation is signaled
308
+ /// via ctx.is_cancelled(), which is checked first in wait()/read_one().
309
+ async fn close_locked(&self) {
310
+ let mut state = self.state.lock().await;
311
+ state.data_closed = true;
312
+ self.local_completed.store(true, Ordering::SeqCst);
313
+ // Don't set remote_err for local closure - rely on ctx.is_cancelled() instead
314
+ // to distinguish between local cancellation and remote errors.
315
+ drop(state);
316
+
317
+ let _ = self.writer.close().await;
318
+ self.notify.notify_waiters();
319
+ self.ctx.cancel();
320
+ }
321
+ }
322
+
323
+ /// Client-side RPC state machine.
324
+ pub struct ClientRpc {
325
+ common: CommonRpc,
326
+ /// Whether CallStart has been sent.
327
+ start_sent: AtomicBool,
328
+ }
329
+
330
+ impl ClientRpc {
331
+ /// Creates a new ClientRpc.
332
+ pub fn new(
333
+ ctx: Context,
334
+ service: String,
335
+ method: String,
336
+ writer: Arc<dyn PacketWriter>,
337
+ ) -> Self {
338
+ Self {
339
+ common: CommonRpc::new(ctx, service, method, writer),
340
+ start_sent: AtomicBool::new(false),
341
+ }
342
+ }
343
+
344
+ /// Returns the context for this RPC.
345
+ pub fn context(&self) -> &Context {
346
+ self.common.context()
347
+ }
348
+
349
+ /// Returns the service ID.
350
+ pub fn service(&self) -> &str {
351
+ self.common.service()
352
+ }
353
+
354
+ /// Returns the method ID.
355
+ pub fn method(&self) -> &str {
356
+ self.common.method()
357
+ }
358
+
359
+ /// Waits for the RPC to finish.
360
+ pub async fn wait(&self) -> Result<()> {
361
+ self.common.wait().await
362
+ }
363
+
364
+ /// Starts the RPC call with optional initial data.
365
+ pub async fn start(&self, data: Option<Bytes>) -> Result<()> {
366
+ if self
367
+ .start_sent
368
+ .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
369
+ .is_err()
370
+ {
371
+ return Err(Error::Completed);
372
+ }
373
+
374
+ // Check context before starting
375
+ if self.common.ctx.is_cancelled() {
376
+ self.common.ctx.cancel();
377
+ let _ = self.common.writer.close().await;
378
+ return Err(Error::Cancelled);
379
+ }
380
+
381
+ let packet = new_call_start(
382
+ self.common.service.clone(),
383
+ self.common.method.clone(),
384
+ data,
385
+ );
386
+
387
+ if let Err(e) = self.common.writer.write_packet(packet).await {
388
+ self.common.ctx.cancel();
389
+ let _ = self.common.writer.close().await;
390
+ return Err(e);
391
+ }
392
+
393
+ Ok(())
394
+ }
395
+
396
+ /// Handles an incoming packet.
397
+ pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
398
+ // Validate the packet first
399
+ packet.validate()?;
400
+
401
+ match packet.body {
402
+ Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
403
+ Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
404
+ Some(Body::CallCancel(false)) => Ok(()),
405
+ Some(Body::CallStart(_)) => {
406
+ // Server-to-client calls not supported
407
+ Err(Error::UnrecognizedPacket)
408
+ }
409
+ None => Err(Error::EmptyPacket),
410
+ }
411
+ }
412
+
413
+ /// Handles stream close from the transport.
414
+ pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
415
+ self.common.handle_stream_close(err).await
416
+ }
417
+
418
+ /// Closes the client RPC.
419
+ ///
420
+ /// This sends a cancel packet (if not already completed) and releases resources.
421
+ /// Matches the Go implementation's `Close()` behavior.
422
+ pub async fn close(&self) {
423
+ // Only proceed if writer was set (start was called)
424
+ if !self.start_sent.load(Ordering::SeqCst) {
425
+ return;
426
+ }
427
+
428
+ // Try to send cancel, ignore errors
429
+ let _ = self.common.write_call_cancel().await;
430
+
431
+ // Close resources
432
+ self.common.close_locked().await;
433
+ }
434
+ }
435
+
436
+ #[async_trait]
437
+ impl Stream for ClientRpc {
438
+ fn context(&self) -> &Context {
439
+ &self.common.ctx
440
+ }
441
+
442
+ async fn send_bytes(&self, data: Bytes) -> Result<()> {
443
+ self.common
444
+ .write_call_data(Some(data), false, None)
445
+ .await
446
+ }
447
+
448
+ async fn recv_bytes(&self) -> Result<Bytes> {
449
+ self.common.read_one().await
450
+ }
451
+
452
+ async fn close_send(&self) -> Result<()> {
453
+ self.common.write_call_data(None, true, None).await
454
+ }
455
+
456
+ async fn close(&self) -> Result<()> {
457
+ ClientRpc::close(self).await;
458
+ Ok(())
459
+ }
460
+ }
461
+
462
+ /// Server-side RPC state machine.
463
+ pub struct ServerRpc {
464
+ common: CommonRpc,
465
+ /// Initial data from CallStart, if any.
466
+ initial_data: Mutex<Option<Bytes>>,
467
+ }
468
+
469
+ impl ServerRpc {
470
+ /// Creates a new ServerRpc from a CallStart packet.
471
+ pub fn from_call_start(
472
+ ctx: Context,
473
+ call_start: CallStart,
474
+ writer: Arc<dyn PacketWriter>,
475
+ ) -> Self {
476
+ let initial_data = decode_optional_data(call_start.data, call_start.data_is_zero);
477
+
478
+ Self {
479
+ common: CommonRpc::new(ctx, call_start.rpc_service, call_start.rpc_method, writer),
480
+ initial_data: Mutex::new(initial_data),
481
+ }
482
+ }
483
+
484
+ /// Returns the context for this RPC.
485
+ pub fn context(&self) -> &Context {
486
+ self.common.context()
487
+ }
488
+
489
+ /// Returns the service ID.
490
+ pub fn service(&self) -> &str {
491
+ self.common.service()
492
+ }
493
+
494
+ /// Returns the method ID.
495
+ pub fn method(&self) -> &str {
496
+ self.common.method()
497
+ }
498
+
499
+ /// Waits for the RPC to finish.
500
+ pub async fn wait(&self) -> Result<()> {
501
+ self.common.wait().await
502
+ }
503
+
504
+ /// Handles an incoming packet.
505
+ pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
506
+ // Validate the packet first
507
+ packet.validate()?;
508
+
509
+ match packet.body {
510
+ Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
511
+ Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
512
+ Some(Body::CallCancel(false)) => Ok(()),
513
+ Some(Body::CallStart(_)) => {
514
+ // CallStart should only be sent once
515
+ Err(Error::DuplicateCallStart)
516
+ }
517
+ None => Err(Error::EmptyPacket),
518
+ }
519
+ }
520
+
521
+ /// Handles stream close from the transport.
522
+ pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
523
+ self.common.handle_stream_close(err).await
524
+ }
525
+
526
+ /// Sends an error response and closes.
527
+ pub async fn send_error(&self, error: String) -> Result<()> {
528
+ self.common.write_call_data(None, true, Some(error)).await
529
+ }
530
+ }
531
+
532
+ #[async_trait]
533
+ impl Stream for ServerRpc {
534
+ fn context(&self) -> &Context {
535
+ &self.common.ctx
536
+ }
537
+
538
+ async fn send_bytes(&self, data: Bytes) -> Result<()> {
539
+ self.common
540
+ .write_call_data(Some(data), false, None)
541
+ .await
542
+ }
543
+
544
+ async fn recv_bytes(&self) -> Result<Bytes> {
545
+ // First check for initial data
546
+ {
547
+ let mut initial = self.initial_data.lock().await;
548
+ if let Some(data) = initial.take() {
549
+ return Ok(data);
550
+ }
551
+ }
552
+
553
+ // Then read from the queue
554
+ self.common.read_one().await
555
+ }
556
+
557
+ async fn close_send(&self) -> Result<()> {
558
+ self.common.write_call_data(None, true, None).await
559
+ }
560
+
561
+ async fn close(&self) -> Result<()> {
562
+ // Mark as locally completed
563
+ self.common.local_completed.store(true, Ordering::SeqCst);
564
+
565
+ // Close the writer
566
+ self.common.writer.close().await?;
567
+
568
+ // Cancel the context
569
+ self.common.ctx.cancel();
570
+
571
+ Ok(())
572
+ }
573
+ }
574
+
575
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    /// Test double for `PacketWriter` that records every written packet and
    /// whether `close` was called.
    struct MockWriter {
        packets: StdMutex<Vec<Packet>>,
        closed: AtomicBool,
    }

    impl MockWriter {
        fn new() -> Self {
            Self {
                packets: StdMutex::default(),
                closed: AtomicBool::default(),
            }
        }

        /// Snapshot of the packets written so far.
        fn packets(&self) -> Vec<Packet> {
            let recorded = self.packets.lock().unwrap();
            recorded.to_vec()
        }

        fn is_closed(&self) -> bool {
            self.closed.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl PacketWriter for MockWriter {
        async fn write_packet(&self, packet: Packet) -> Result<()> {
            self.packets.lock().unwrap().push(packet);
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// A `CommonRpc` for "svc"/"method" with a fresh context.
    fn common_rpc(writer: Arc<MockWriter>) -> CommonRpc {
        CommonRpc::new(Context::new(), "svc".into(), "method".into(), writer)
    }

    /// A `ClientRpc` for "test.Service"/"TestMethod" with a fresh context.
    fn client_rpc(writer: Arc<MockWriter>) -> ClientRpc {
        ClientRpc::new(
            Context::new(),
            "test.Service".into(),
            "TestMethod".into(),
            writer,
        )
    }

    /// Builds a CallData body with `data_is_zero` unset.
    fn call_data(data: Vec<u8>, complete: bool, error: &str) -> CallData {
        CallData {
            data,
            data_is_zero: false,
            complete,
            error: error.to_string(),
        }
    }

    #[tokio::test]
    async fn test_client_rpc_start() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_rpc(Arc::clone(&writer));

        rpc.start(Some(Bytes::from(vec![1, 2, 3]))).await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);

        match sent[0].body.as_ref() {
            Some(Body::CallStart(cs)) => {
                assert_eq!(cs.rpc_service, "test.Service");
                assert_eq!(cs.rpc_method, "TestMethod");
                assert_eq!(cs.data, vec![1, 2, 3]);
                assert!(!cs.data_is_zero);
            }
            _ => panic!("expected CallStart"),
        }
    }

    #[tokio::test]
    async fn test_client_rpc_double_start_fails() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_rpc(writer);

        rpc.start(None).await.unwrap();
        let second = rpc.start(None).await;
        assert!(matches!(second, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_client_rpc_close_sends_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_rpc(Arc::clone(&writer));

        rpc.start(None).await.unwrap();
        rpc.close().await;

        let sent = writer.packets();
        assert_eq!(sent.len(), 2);
        // CallStart first, then the CallCancel emitted by close().
        assert!(sent[1].is_call_cancel());
        assert!(writer.is_closed());
    }

    #[tokio::test]
    async fn test_server_rpc_from_call_start() {
        let start = CallStart {
            rpc_service: "test.Service".into(),
            rpc_method: "TestMethod".into(),
            data: vec![1, 2, 3],
            data_is_zero: false,
        };
        let writer = Arc::new(MockWriter::new());
        let rpc = ServerRpc::from_call_start(Context::new(), start, writer);

        assert_eq!(rpc.service(), "test.Service");
        assert_eq!(rpc.method(), "TestMethod");

        // The CallStart payload is what the first receive yields.
        let first = rpc.recv_bytes().await.unwrap();
        assert_eq!(&first[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_data() {
        let rpc = common_rpc(Arc::new(MockWriter::new()));

        rpc.handle_call_data(call_data(vec![1, 2, 3], false, ""))
            .await
            .unwrap();

        let got = rpc.read_one().await.unwrap();
        assert_eq!(&got[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_stream_closed() {
        let rpc = common_rpc(Arc::new(MockWriter::new()));

        rpc.handle_call_data(call_data(vec![], true, ""))
            .await
            .unwrap();

        let outcome = rpc.read_one().await;
        assert!(matches!(outcome, Err(Error::StreamClosed)));
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_error() {
        let rpc = common_rpc(Arc::new(MockWriter::new()));

        rpc.handle_call_data(call_data(vec![], true, "test error"))
            .await
            .unwrap();

        if let Err(Error::Remote(msg)) = rpc.read_one().await {
            assert_eq!(msg, "test error");
        } else {
            panic!("expected Remote error");
        }
    }

    #[tokio::test]
    async fn test_write_call_data_after_complete() {
        let rpc = common_rpc(Arc::new(MockWriter::new()));

        rpc.write_call_data(None, true, None).await.unwrap();

        // Once completed, sending more data is rejected.
        let late = rpc
            .write_call_data(Some(Bytes::from(vec![1])), false, None)
            .await;
        assert!(matches!(late, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_write_call_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_rpc(Arc::clone(&writer));

        rpc.write_call_cancel().await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].is_call_cancel());

        // A second cancel is rejected because the RPC already completed.
        let again = rpc.write_call_cancel().await;
        assert!(matches!(again, Err(Error::Completed)));
    }
}