starpc 0.49.15 → 0.49.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,12 +10,14 @@ export declare class CommonRPC {
10
10
  protected service?: string;
11
11
  protected method?: string;
12
12
  private closed?;
13
+ private readonly writeDrainAbort;
13
14
  constructor();
14
15
  get isClosed(): boolean | Error;
15
16
  writeCallData(data?: Uint8Array, complete?: boolean, error?: string): Promise<void>;
17
+ private writeCallDataPacket;
16
18
  writeCallCancel(): Promise<void>;
17
19
  writeCallDataFromSource(dataSource: AsyncIterable<Uint8Array>): Promise<void>;
18
- protected writePacket(packet: Packet): Promise<void>;
20
+ protected writePacket(packet: Packet, options?: WritePacketOptions): Promise<void>;
19
21
  handleMessage(message: Uint8Array): Promise<void>;
20
22
  handlePacket(packet: Packet): Promise<void>;
21
23
  handleCallStart(packet: Partial<CallStart>): Promise<void>;
@@ -25,3 +27,8 @@ export declare class CommonRPC {
25
27
  close(err?: Error): Promise<void>;
26
28
  private _createSink;
27
29
  }
30
/** Options controlling how CommonRPC.writePacket enqueues a packet. */
interface WritePacketOptions {
    /** Write the packet even if the rpc has already been closed. */
    allowClosed?: boolean;
    /** When false, return right after queuing instead of waiting for the outbound queue to drain. */
    waitForDrain?: boolean;
}
34
+ export {};
@@ -1,6 +1,7 @@
1
1
  import { pushable } from 'it-pushable';
2
2
  import { Packet } from './rpcproto.pb.js';
3
3
  import { ERR_RPC_ABORT, RemoteRPCError } from './errors.js';
4
// maxBufferedOutgoingPackets is the outbound queue length above which
// writePacket waits for the queue to drain (onEmpty) before resolving.
const maxBufferedOutgoingPackets = 1;
4
5
  // CommonRPC is common logic between server and client RPCs.
5
6
  export class CommonRPC {
6
7
  // sink is the data sink for incoming messages.
@@ -23,6 +24,8 @@ export class CommonRPC {
23
24
  method;
24
25
  // closed indicates this rpc has been closed already.
25
26
  closed;
27
+ // writeDrainAbort wakes writers waiting for outbound stream drain on close.
28
+ writeDrainAbort = new AbortController();
26
29
  constructor() {
27
30
  this.sink = this._createSink();
28
31
  this.source = this._source;
@@ -34,6 +37,10 @@ export class CommonRPC {
34
37
  }
35
38
  // writeCallData writes the call data packet.
36
39
  async writeCallData(data, complete, error) {
40
+ await this.writeCallDataPacket(data, complete, error);
41
+ }
42
+ // writeCallDataPacket writes a call-data packet with optional drain control.
43
+ async writeCallDataPacket(data, complete, error, writeOptions) {
37
44
  const callData = {
38
45
  data: data || new Uint8Array(0),
39
46
  dataIsZero: !!data && data.length === 0,
@@ -45,7 +52,7 @@ export class CommonRPC {
45
52
  case: 'callData',
46
53
  value: callData,
47
54
  },
48
- });
55
+ }, writeOptions);
49
56
  }
50
57
  // writeCallCancel writes the call cancel packet.
51
58
  async writeCallCancel() {
@@ -54,7 +61,7 @@ export class CommonRPC {
54
61
  case: 'callCancel',
55
62
  value: true,
56
63
  },
57
- });
64
+ }, { waitForDrain: false });
58
65
  }
59
66
  // writeCallDataFromSource writes all call data from the iterable.
60
67
  async writeCallDataFromSource(dataSource) {
@@ -69,8 +76,24 @@ export class CommonRPC {
69
76
  }
70
77
  }
71
78
  // writePacket writes a packet to the stream.
72
- async writePacket(packet) {
79
+ async writePacket(packet, options) {
80
+ if (this.closed && !options?.allowClosed) {
81
+ throw new Error(ERR_RPC_ABORT);
82
+ }
73
83
  this._source.push(packet);
84
+ if (options?.waitForDrain === false ||
85
+ this._source.readableLength <= maxBufferedOutgoingPackets) {
86
+ return;
87
+ }
88
+ try {
89
+ await this._source.onEmpty({ signal: this.writeDrainAbort.signal });
90
+ }
91
+ catch (err) {
92
+ if (this.closed) {
93
+ throw new Error(ERR_RPC_ABORT, { cause: err });
94
+ }
95
+ throw err;
96
+ }
74
97
  }
75
98
  // handleMessage handles an incoming encoded Packet.
76
99
  //
@@ -149,8 +172,12 @@ export class CommonRPC {
149
172
  this.closed = err ?? true;
150
173
  // note: this does nothing if _source is already ended.
151
174
  if (err && err.message) {
152
- await this.writeCallData(undefined, true, err.message);
175
+ await this.writeCallDataPacket(undefined, true, err.message, {
176
+ allowClosed: true,
177
+ waitForDrain: false,
178
+ });
153
179
  }
180
+ this.writeDrainAbort.abort();
154
181
  this._source.end();
155
182
  this._rpcDataSource.end(err);
156
183
  }
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,178 @@
1
+ import { describe, expect, it } from 'vitest';
2
+ import { Client } from './client.js';
3
+ import { CommonRPC } from './common-rpc.js';
4
+ import { Packet } from './rpcproto.pb.js';
5
describe('CommonRPC', () => {
    // startBlockedCallDataWriter begins writing 32 call-data chunks into rpc
    // before anything reads rpc.source, and asserts that the writer stalls
    // after the second chunk. The write promise is returned inside an object
    // so awaiting this helper does not also wait for the write to finish.
    async function startBlockedCallDataWriter(rpc, yielded) {
        const secondYield = deferred();
        const thirdYield = deferred();
        const writeDone = rpc.writeCallDataFromSource(chunkSource(yielded, new Map([
            [1, secondYield],
            [2, thirdYield],
        ])));
        await promiseWithTimeout(secondYield.promise, 'second call-data chunk');
        expect(yielded.count).toBe(2);
        await expectPending(thirdYield.promise, 'third call-data chunk');
        expect(yielded.count).toBe(2);
        return { writeDone };
    }
    it('backpressures call-data sources until outbound packets drain', async () => {
        const rpc = new CommonRPC();
        const yielded = { count: 0 };
        const { writeDone } = await startBlockedCallDataWriter(rpc, yielded);
        // Drain the outbound source, recording the first byte of every chunk
        // until the completion packet arrives.
        const received = [];
        const readComplete = (async () => {
            for await (const packet of rpc.source) {
                const body = packet.body;
                if (body?.case !== 'callData') {
                    continue;
                }
                if (body.value.complete) {
                    return;
                }
                if (!body.value.data) {
                    throw new Error('call data packet missing data');
                }
                received.push(body.value.data[0]);
            }
            throw new Error('outbound call data ended before completion');
        })();
        await promiseWithTimeout(Promise.all([writeDone, readComplete]), 'backpressured call data drain');
        await rpc.close();
        expect(yielded.count).toBe(32);
        expect(received).toHaveLength(32);
        expect(received[0]).toBe(0);
        expect(received[31]).toBe(31);
    });
    it('wakes call-data writers when the rpc closes while waiting for drain', async () => {
        const rpc = new CommonRPC();
        const yielded = { count: 0 };
        const { writeDone } = await startBlockedCallDataWriter(rpc, yielded);
        await rpc.close();
        await expect(promiseWithTimeout(writeDone, 'call data writer close')).resolves.toBeUndefined();
        expect(yielded.count).toBeLessThan(32);
    });
    it('wakes call-data writers when an error closes the rpc while waiting for drain', async () => {
        const rpc = new CommonRPC();
        const yielded = { count: 0 };
        const { writeDone } = await startBlockedCallDataWriter(rpc, yielded);
        await rpc.close(new Error('boom'));
        await expect(promiseWithTimeout(writeDone, 'call data writer error close')).resolves.toBeUndefined();
        expect(yielded.count).toBeLessThan(32);
    });
    it('does not backpressure call-cancel packets behind queued call data', async () => {
        const rpc = new CommonRPC();
        const yielded = { count: 0 };
        const { writeDone } = await startBlockedCallDataWriter(rpc, yielded);
        const cancelDone = rpc.writeCallCancel();
        await rpc.close(new Error('abort'));
        await expect(promiseWithTimeout(cancelDone, 'call cancel write')).resolves.toBeUndefined();
        await expect(promiseWithTimeout(writeDone, 'blocked call data close')).resolves.toBeUndefined();
    });
    it('backpressures client-stream sources through the Client request path', async () => {
        const yielded = { count: 0 };
        const firstYield = deferred();
        const secondYield = deferred();
        const sinkConsumed = { count: 0 };
        const sinkGate = deferred();
        const responseGate = deferred();
        const response = new Uint8Array([7]);
        // The fake transport's sink reads nothing until sinkGate resolves, and
        // its source sends the response only after responseGate resolves.
        const client = new Client(async () => ({
            source: (async function* () {
                await responseGate.promise;
                yield Packet.toBinary({
                    body: {
                        case: 'callData',
                        value: {
                            data: response,
                            dataIsZero: false,
                            complete: false,
                            error: '',
                        },
                    },
                });
            })(),
            sink: async (source) => {
                await sinkGate.promise;
                for await (const _packet of source) {
                    sinkConsumed.count++;
                }
            },
        }));
        const request = client.clientStreamingRequest('test.Service', 'Upload', chunkSource(yielded, new Map([
            [0, firstYield],
            [1, secondYield],
        ])));
        await promiseWithTimeout(firstYield.promise, 'first upload chunk');
        expect(yielded.count).toBe(1);
        await expectPending(secondYield.promise, 'second upload chunk');
        expect(yielded.count).toBe(1);
        expect(sinkConsumed.count).toBe(0);
        await expectPending(request, 'client-stream request');
        sinkGate.resolve();
        responseGate.resolve();
        await expect(promiseWithTimeout(request, 'client-stream response')).resolves.toEqual(response);
    });
});
140
// chunkSource yields 32 one-byte chunks holding 0..31. Before yielding chunk i
// it increments yielded.count and resolves yieldMarks.get(i), if present.
async function* chunkSource(yielded, yieldMarks) {
    for (let i = 0; i < 32; i += 1) {
        yielded.count++;
        const mark = yieldMarks?.get(i);
        mark?.resolve();
        yield new Uint8Array([i]);
    }
}
147
// promiseWithTimeout settles like `promise`, or rejects with
// "timed out waiting for <label>" if `promise` has not settled within
// `timeoutMs` (default 500ms). The timer is cleared once the race settles so
// it neither keeps the event loop alive nor fires after the caller moves on.
function promiseWithTimeout(promise, label, timeoutMs = 500) {
    let timer;
    const timeout = new Promise((_, reject) => {
        timer = setTimeout(() => reject(new Error(`timed out waiting for ${label}`)), timeoutMs);
    });
    return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
}
155
// expectPending asserts that `promise` is still unsettled after 50ms.
// If it settles first, the assertion fails with a labelled message.
async function expectPending(promise, label) {
    const outcome = Promise.race([
        promise.then(() => 'settled'),
        new Promise((resolve) => {
            setTimeout(() => resolve('pending'), 50);
        }),
    ]);
    const message = `${label} should not be requested while outbound packets are blocked`;
    await expect(outcome, message).resolves.toBe('pending');
}
163
// deferred returns a promise plus resolve/reject functions that settle it
// from outside the Promise executor.
function deferred() {
    let resolvePromise;
    let rejectPromise;
    const promise = new Promise((resolve, reject) => {
        resolvePromise = resolve;
        rejectPromise = reject;
    });
    return {
        promise,
        resolve() {
            resolvePromise?.();
        },
        reject(reason) {
            rejectPromise?.(reason);
        },
    };
}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "starpc",
3
- "version": "0.49.15",
3
+ "version": "0.49.16",
4
4
  "description": "Streaming protobuf RPC service protocol over any two-way channel.",
5
5
  "license": "MIT",
6
6
  "author": {
@@ -31,7 +31,7 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
31
31
  }
32
32
 
33
33
  if err := r.ctx.Err(); err != nil {
34
- r.ctxCancel()
34
+ r.cancelContext()
35
35
  _ = writer.Close()
36
36
  return context.Canceled
37
37
  }
@@ -48,7 +48,7 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
48
48
  pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
49
49
  err = writer.WritePacket(pkt)
50
50
  if err != nil {
51
- r.ctxCancel()
51
+ r.cancelContext()
52
52
  _ = writer.Close()
53
53
  }
54
54
 
@@ -74,7 +74,7 @@ func (r *ClientRPC) HandleStreamClose(closeErr error) {
74
74
  r.remoteErr = closeErr
75
75
  }
76
76
  r.dataClosed = true
77
- r.ctxCancel()
77
+ r.cancelContext()
78
78
  locked.Broadcast()
79
79
  locked.Unlock()
80
80
  }
package/srpc/client.go CHANGED
@@ -91,7 +91,7 @@ func (c *client) NewStream(ctx context.Context, service, method string, firstMsg
91
91
  }
92
92
 
93
93
  return NewMsgStream(ctx, clientRPC, func() {
94
- clientRPC.ctxCancel()
94
+ clientRPC.cancelContext()
95
95
  _ = writer.Close()
96
96
  }), nil
97
97
  }
@@ -5,16 +5,19 @@ import (
5
5
  "io"
6
6
  "sync/atomic"
7
7
 
8
+ "github.com/aperturerobotics/starpc/internal/contextutil"
8
9
  "github.com/aperturerobotics/util/broadcast"
9
10
  "github.com/pkg/errors"
10
11
  )
11
12
 
12
13
  // commonRPC contains common logic between server/client rpc.
13
14
  type commonRPC struct {
14
- // ctx is the context, canceled when the rpc ends.
15
+ // ctx is the RPC context, canceled when the RPC is canceled.
15
16
  ctx context.Context
16
- // ctxCancel is called when the rpc ends.
17
+ // ctxCancel cancels ctx.
17
18
  ctxCancel context.CancelFunc
19
+ // ctxCanceled tracks whether ctxCancel has already been called.
20
+ ctxCanceled atomic.Bool
18
21
  // service is the rpc service
19
22
  service string
20
23
  // method is the rpc method
@@ -28,6 +31,11 @@ type commonRPC struct {
28
31
  writer PacketWriter
29
32
  // writerClosed is set after writer has been closed locally.
30
33
  writerClosed bool
34
+ // localCompleting is set while the local handler is publishing its terminal
35
+ // packet and closing the writer.
36
+ localCompleting bool
37
+ // localDone is set after the local handler has completed normally.
38
+ localDone bool
31
39
  // dataQueue contains incoming data packets.
32
40
  // note: packets may be len() == 0
33
41
  dataQueue [][]byte
@@ -40,7 +48,14 @@ type commonRPC struct {
40
48
 
41
49
  // initCommonRPC initializes the commonRPC.
42
50
  func initCommonRPC(ctx context.Context, rpc *commonRPC) {
43
- rpc.ctx, rpc.ctxCancel = context.WithCancel(ctx)
51
+ rpc.ctx, rpc.ctxCancel = contextutil.WithCancel(ctx)
52
+ }
53
+
54
+ func (c *commonRPC) cancelContext() {
55
+ if c.ctxCanceled.Swap(true) {
56
+ return
57
+ }
58
+ c.ctxCancel()
44
59
  }
45
60
 
46
61
  // Context is canceled when the rpc has finished.
@@ -53,10 +68,13 @@ func (c *commonRPC) Wait(ctx context.Context) error {
53
68
  for {
54
69
  var err error
55
70
  var waitCh <-chan struct{}
56
- var rpcCtx context.Context
71
+ var rpcCanceled bool
72
+ var localDone bool
57
73
  locked := c.bcast.Lock()
58
- rpcCtx, err = c.ctx, c.remoteErr
59
- if err == nil && rpcCtx.Err() == nil {
74
+ err = c.remoteErr
75
+ rpcCanceled = c.ctx.Err() != nil
76
+ localDone = c.localDone
77
+ if err == nil && !rpcCanceled && !localDone {
60
78
  waitCh = locked.WaitCh()
61
79
  }
62
80
  locked.Unlock()
@@ -64,15 +82,16 @@ func (c *commonRPC) Wait(ctx context.Context) error {
64
82
  if err != nil {
65
83
  return err
66
84
  }
67
- if rpcCtx.Err() != nil {
68
- // rpc must have ended w/o an error being set
85
+ if localDone {
86
+ return nil
87
+ }
88
+ if rpcCanceled {
69
89
  return context.Canceled
70
90
  }
71
91
 
72
92
  select {
73
93
  case <-ctx.Done():
74
94
  return context.Canceled
75
- case <-rpcCtx.Done():
76
95
  case <-waitCh:
77
96
  }
78
97
  }
@@ -153,12 +172,15 @@ func (c *commonRPC) HandleStreamClose(closeErr error) {
153
172
  locked.Unlock()
154
173
  return
155
174
  }
175
+ normalRemoteCloseAfterLocalComplete := closeErr == nil && (c.localCompleting || c.localDone)
156
176
  if closeErr != nil && c.remoteErr == nil {
157
177
  c.remoteErr = closeErr
158
178
  }
159
179
  c.dataClosed = true
160
- c.ctxCancel()
161
- writer = c.closeWriterLocked()
180
+ if !normalRemoteCloseAfterLocalComplete {
181
+ c.cancelContext()
182
+ writer = c.closeWriterLocked()
183
+ }
162
184
  locked.Broadcast()
163
185
  locked.Unlock()
164
186
  if writer != nil {
@@ -225,7 +247,7 @@ func (c *commonRPC) closeLocked(locked *broadcast.Locked) PacketWriter {
225
247
  }
226
248
  writer := c.closeWriterLocked()
227
249
  locked.Broadcast()
228
- c.ctxCancel()
250
+ c.cancelContext()
229
251
  return writer
230
252
  }
231
253
 
@@ -236,3 +258,25 @@ func (c *commonRPC) closeWriterLocked() PacketWriter {
236
258
  c.writerClosed = true
237
259
  return c.writer
238
260
  }
261
+
262
+ func (c *commonRPC) beginLocalCompletion() {
263
+ locked := c.bcast.Lock()
264
+ c.localCompleted.Store(true)
265
+ c.localCompleting = true
266
+ locked.Unlock()
267
+ }
268
+
269
+ func (c *commonRPC) finishLocalCompletion() {
270
+ locked := c.bcast.Lock()
271
+ c.localCompleted.Store(true)
272
+ writer := c.closeWriterLocked()
273
+ locked.Unlock()
274
+ if writer != nil {
275
+ _ = writer.Close()
276
+ }
277
+ locked = c.bcast.Lock()
278
+ c.localCompleting = false
279
+ c.localDone = true
280
+ locked.Broadcast()
281
+ locked.Unlock()
282
+ }
@@ -0,0 +1,256 @@
1
+ import { describe, expect, it } from 'vitest'
2
+
3
+ import { Client } from './client.js'
4
+ import { CommonRPC } from './common-rpc.js'
5
+ import { Packet } from './rpcproto.pb.js'
6
+
7
describe('CommonRPC', () => {
  // startBlockedCallDataWriter begins writing 32 call-data chunks into rpc
  // before anything reads rpc.source, and asserts that the writer stalls after
  // the second chunk. The write promise is returned inside an object so that
  // awaiting this helper does not also wait for the write to finish.
  async function startBlockedCallDataWriter(
    rpc: CommonRPC,
    yielded: { count: number },
  ) {
    const secondYield = deferred()
    const thirdYield = deferred()
    const writeDone = rpc.writeCallDataFromSource(
      chunkSource(
        yielded,
        new Map([
          [1, secondYield],
          [2, thirdYield],
        ]),
      ),
    )

    await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
    expect(yielded.count).toBe(2)
    await expectPending(thirdYield.promise, 'third call-data chunk')
    expect(yielded.count).toBe(2)

    return { writeDone }
  }

  it('backpressures call-data sources until outbound packets drain', async () => {
    const rpc = new CommonRPC()
    const yielded = { count: 0 }
    const { writeDone } = await startBlockedCallDataWriter(rpc, yielded)

    // Drain the outbound source, recording the first byte of every chunk
    // until the completion packet arrives.
    const received: number[] = []
    const readComplete = (async () => {
      for await (const packet of rpc.source) {
        const body = packet.body
        if (body?.case !== 'callData') {
          continue
        }
        if (body.value.complete) {
          return
        }
        if (!body.value.data) {
          throw new Error('call data packet missing data')
        }
        received.push(body.value.data[0])
      }
      throw new Error('outbound call data ended before completion')
    })()

    await promiseWithTimeout(
      Promise.all([writeDone, readComplete]),
      'backpressured call data drain',
    )
    await rpc.close()

    expect(yielded.count).toBe(32)
    expect(received).toHaveLength(32)
    expect(received[0]).toBe(0)
    expect(received[31]).toBe(31)
  })

  it('wakes call-data writers when the rpc closes while waiting for drain', async () => {
    const rpc = new CommonRPC()
    const yielded = { count: 0 }
    const { writeDone } = await startBlockedCallDataWriter(rpc, yielded)

    await rpc.close()
    await expect(
      promiseWithTimeout(writeDone, 'call data writer close'),
    ).resolves.toBeUndefined()
    expect(yielded.count).toBeLessThan(32)
  })

  it('wakes call-data writers when an error closes the rpc while waiting for drain', async () => {
    const rpc = new CommonRPC()
    const yielded = { count: 0 }
    const { writeDone } = await startBlockedCallDataWriter(rpc, yielded)

    await rpc.close(new Error('boom'))
    await expect(
      promiseWithTimeout(writeDone, 'call data writer error close'),
    ).resolves.toBeUndefined()
    expect(yielded.count).toBeLessThan(32)
  })

  it('does not backpressure call-cancel packets behind queued call data', async () => {
    const rpc = new CommonRPC()
    const yielded = { count: 0 }
    const { writeDone } = await startBlockedCallDataWriter(rpc, yielded)

    const cancelDone = rpc.writeCallCancel()
    await rpc.close(new Error('abort'))

    await expect(
      promiseWithTimeout(cancelDone, 'call cancel write'),
    ).resolves.toBeUndefined()
    await expect(
      promiseWithTimeout(writeDone, 'blocked call data close'),
    ).resolves.toBeUndefined()
  })

  it('backpressures client-stream sources through the Client request path', async () => {
    const yielded = { count: 0 }
    const firstYield = deferred()
    const secondYield = deferred()
    const sinkConsumed = { count: 0 }
    const sinkGate = deferred()
    const responseGate = deferred()
    const response = new Uint8Array([7])
    // The fake transport's sink reads nothing until sinkGate resolves, and its
    // source sends the response only after responseGate resolves.
    const client = new Client(async () => ({
      source: (async function* () {
        await responseGate.promise
        yield Packet.toBinary({
          body: {
            case: 'callData',
            value: {
              data: response,
              dataIsZero: false,
              complete: false,
              error: '',
            },
          },
        })
      })(),
      sink: async (source) => {
        await sinkGate.promise
        for await (const _packet of source) {
          sinkConsumed.count++
        }
      },
    }))

    const request = client.clientStreamingRequest(
      'test.Service',
      'Upload',
      chunkSource(
        yielded,
        new Map([
          [0, firstYield],
          [1, secondYield],
        ]),
      ),
    )

    await promiseWithTimeout(firstYield.promise, 'first upload chunk')
    expect(yielded.count).toBe(1)
    await expectPending(secondYield.promise, 'second upload chunk')
    expect(yielded.count).toBe(1)
    expect(sinkConsumed.count).toBe(0)
    await expectPending(request, 'client-stream request')

    sinkGate.resolve()
    responseGate.resolve()
    await expect(
      promiseWithTimeout(request, 'client-stream response'),
    ).resolves.toEqual(response)
  })
})
203
+
204
// chunkSource yields 32 one-byte chunks holding 0..31. Before yielding chunk i
// it increments yielded.count and resolves yieldMarks.get(i), if present.
async function* chunkSource(
  yielded: { count: number },
  yieldMarks?: Map<number, Deferred>,
) {
  for (let i = 0; i < 32; i += 1) {
    yielded.count++
    const mark = yieldMarks?.get(i)
    mark?.resolve()
    yield new Uint8Array([i])
  }
}
214
+
215
// promiseWithTimeout settles like `promise`, or rejects with
// "timed out waiting for <label>" if `promise` has not settled within
// `timeoutMs` (default 500ms). The timer is cleared once the race settles so
// it neither keeps the event loop alive nor fires after the caller moves on.
function promiseWithTimeout<T>(
  promise: Promise<T>,
  label: string,
  timeoutMs = 500,
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(
      () => reject(new Error(`timed out waiting for ${label}`)),
      timeoutMs,
    )
  })
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer))
}
223
+
224
// expectPending asserts that `promise` is still unsettled after 50ms.
// If it settles first, the assertion fails with a labelled message.
async function expectPending<T>(promise: Promise<T>, label: string) {
  const outcome = Promise.race([
    promise.then(() => 'settled'),
    new Promise<'pending'>((resolve) => {
      setTimeout(() => resolve('pending'), 50)
    }),
  ])
  const message = `${label} should not be requested while outbound packets are blocked`
  await expect(outcome, message).resolves.toBe('pending')
}
235
+
236
type Deferred = ReturnType<typeof deferred>

// deferred returns a void promise plus resolve/reject functions that settle
// it from outside the Promise executor.
function deferred() {
  let resolvePromise: (() => void) | undefined
  let rejectPromise: ((reason?: unknown) => void) | undefined
  const promise = new Promise<void>((resolve, reject) => {
    resolvePromise = resolve
    rejectPromise = reject
  })
  return {
    promise,
    resolve() {
      resolvePromise?.()
    },
    reject(reason?: unknown) {
      rejectPromise?.(reason)
    },
  }
}
@@ -1,11 +1,13 @@
1
1
  import type { Sink, Source } from 'it-stream-types'
2
- import { pushable } from 'it-pushable'
2
+ import { pushable, type Pushable } from 'it-pushable'
3
3
  import { CompleteMessage } from '@aptre/protobuf-es-lite'
4
4
 
5
5
  import type { CallData, CallStart } from './rpcproto.pb.js'
6
6
  import { Packet } from './rpcproto.pb.js'
7
7
  import { ERR_RPC_ABORT, RemoteRPCError } from './errors.js'
8
8
 
9
// maxBufferedOutgoingPackets is the outbound queue length above which
// writePacket waits for the queue to drain (onEmpty) before resolving.
const maxBufferedOutgoingPackets = 1
10
+
9
11
  // CommonRPC is common logic between server and client RPCs.
10
12
  export class CommonRPC {
11
13
  // sink is the data sink for incoming messages.
@@ -16,7 +18,7 @@ export class CommonRPC {
16
18
  public readonly rpcDataSource: AsyncIterable<Uint8Array>
17
19
 
18
20
  // _source is used to write to the source.
19
- private readonly _source = pushable<Packet>({
21
+ private readonly _source: Pushable<Packet> = pushable<Packet>({
20
22
  objectMode: true,
21
23
  })
22
24
 
@@ -32,6 +34,8 @@ export class CommonRPC {
32
34
 
33
35
  // closed indicates this rpc has been closed already.
34
36
  private closed?: true | Error
37
+ // writeDrainAbort wakes writers waiting for outbound stream drain on close.
38
+ private readonly writeDrainAbort = new AbortController()
35
39
 
36
40
  constructor() {
37
41
  this.sink = this._createSink()
@@ -49,6 +53,16 @@ export class CommonRPC {
49
53
  data?: Uint8Array,
50
54
  complete?: boolean,
51
55
  error?: string,
56
+ ) {
57
+ await this.writeCallDataPacket(data, complete, error)
58
+ }
59
+
60
+ // writeCallDataPacket writes a call-data packet with optional drain control.
61
+ private async writeCallDataPacket(
62
+ data?: Uint8Array,
63
+ complete?: boolean,
64
+ error?: string,
65
+ writeOptions?: WritePacketOptions,
52
66
  ) {
53
67
  const callData: CompleteMessage<CallData> = {
54
68
  data: data || new Uint8Array(0),
@@ -56,22 +70,28 @@ export class CommonRPC {
56
70
  complete: complete || false,
57
71
  error: error || '',
58
72
  }
59
- await this.writePacket({
60
- body: {
61
- case: 'callData',
62
- value: callData,
73
+ await this.writePacket(
74
+ {
75
+ body: {
76
+ case: 'callData',
77
+ value: callData,
78
+ },
63
79
  },
64
- })
80
+ writeOptions,
81
+ )
65
82
  }
66
83
 
67
84
  // writeCallCancel writes the call cancel packet.
68
85
  public async writeCallCancel() {
69
- await this.writePacket({
70
- body: {
71
- case: 'callCancel',
72
- value: true,
86
+ await this.writePacket(
87
+ {
88
+ body: {
89
+ case: 'callCancel',
90
+ value: true,
91
+ },
73
92
  },
74
- })
93
+ { waitForDrain: false },
94
+ )
75
95
  }
76
96
 
77
97
  // writeCallDataFromSource writes all call data from the iterable.
@@ -87,8 +107,25 @@ export class CommonRPC {
87
107
  }
88
108
 
89
109
  // writePacket writes a packet to the stream.
90
- protected async writePacket(packet: Packet) {
110
+ protected async writePacket(packet: Packet, options?: WritePacketOptions) {
111
+ if (this.closed && !options?.allowClosed) {
112
+ throw new Error(ERR_RPC_ABORT)
113
+ }
91
114
  this._source.push(packet)
115
+ if (
116
+ options?.waitForDrain === false ||
117
+ this._source.readableLength <= maxBufferedOutgoingPackets
118
+ ) {
119
+ return
120
+ }
121
+ try {
122
+ await this._source.onEmpty({ signal: this.writeDrainAbort.signal })
123
+ } catch (err) {
124
+ if (this.closed) {
125
+ throw new Error(ERR_RPC_ABORT, { cause: err })
126
+ }
127
+ throw err
128
+ }
92
129
  }
93
130
 
94
131
  // handleMessage handles an incoming encoded Packet.
@@ -179,8 +216,12 @@ export class CommonRPC {
179
216
  this.closed = err ?? true
180
217
  // note: this does nothing if _source is already ended.
181
218
  if (err && err.message) {
182
- await this.writeCallData(undefined, true, err.message)
219
+ await this.writeCallDataPacket(undefined, true, err.message, {
220
+ allowClosed: true,
221
+ waitForDrain: false,
222
+ })
183
223
  }
224
+ this.writeDrainAbort.abort()
184
225
  this._source.end()
185
226
  this._rpcDataSource.end(err)
186
227
  }
@@ -206,3 +247,8 @@ export class CommonRPC {
206
247
  }
207
248
  }
208
249
  }
250
+
251
// WritePacketOptions controls how writePacket enqueues a packet.
interface WritePacketOptions {
  // allowClosed permits writing even after the rpc has been closed
  // (used by close to send the terminal error packet).
  allowClosed?: boolean
  // waitForDrain=false returns right after queuing instead of waiting for
  // the outbound queue to drain.
  waitForDrain?: boolean
}
@@ -4,8 +4,10 @@ import (
4
4
  "bytes"
5
5
  "context"
6
6
  "io"
7
+ "sync"
7
8
  "sync/atomic"
8
9
  "testing"
10
+ "time"
9
11
  )
10
12
 
11
13
  type closeCountingPacketWriter struct {
@@ -16,6 +18,19 @@ type closeCallbackPacketWriter struct {
16
18
  closeFn func()
17
19
  }
18
20
 
21
// packetRecordingWriter is a test PacketWriter. It forwards each written
// packet to the packets channel and closes the closed channel on the first
// Close.
type packetRecordingWriter struct {
	// packets receives every packet passed to WritePacket.
	packets chan *Packet
	// closed is closed by the first call to Close.
	closed chan struct{}
	// once guards closing closed.
	once sync.Once
}

// blockingPacketWriter is like packetRecordingWriter, except that each
// WritePacket call blocks after publishing its packet until releaseWrite
// yields.
type blockingPacketWriter struct {
	// packets receives every packet passed to WritePacket.
	packets chan *Packet
	// releaseWrite unblocks pending WritePacket calls.
	releaseWrite chan struct{}
	// closed is closed by the first call to Close.
	closed chan struct{}
	// once guards closing closed.
	once sync.Once
}
33
+
19
34
  func (w *closeCountingPacketWriter) WritePacket(*Packet) error {
20
35
  return nil
21
36
  }
@@ -36,6 +51,46 @@ func (w *closeCallbackPacketWriter) Close() error {
36
51
  return nil
37
52
  }
38
53
 
54
+ func newPacketRecordingWriter() *packetRecordingWriter {
55
+ return &packetRecordingWriter{
56
+ packets: make(chan *Packet, 1),
57
+ closed: make(chan struct{}),
58
+ }
59
+ }
60
+
61
+ func (w *packetRecordingWriter) WritePacket(pkt *Packet) error {
62
+ w.packets <- pkt
63
+ return nil
64
+ }
65
+
66
+ func (w *packetRecordingWriter) Close() error {
67
+ w.once.Do(func() {
68
+ close(w.closed)
69
+ })
70
+ return nil
71
+ }
72
+
73
+ func newBlockingPacketWriter() *blockingPacketWriter {
74
+ return &blockingPacketWriter{
75
+ packets: make(chan *Packet, 1),
76
+ releaseWrite: make(chan struct{}),
77
+ closed: make(chan struct{}),
78
+ }
79
+ }
80
+
81
+ func (w *blockingPacketWriter) WritePacket(pkt *Packet) error {
82
+ w.packets <- pkt
83
+ <-w.releaseWrite
84
+ return nil
85
+ }
86
+
87
+ func (w *blockingPacketWriter) Close() error {
88
+ w.once.Do(func() {
89
+ close(w.closed)
90
+ })
91
+ return nil
92
+ }
93
+
39
94
  func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
40
95
  writer := &closeCountingPacketWriter{}
41
96
  rpc := NewServerRPC(context.Background(), InvokerFunc(nil), writer)
@@ -48,6 +103,23 @@ func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
48
103
  }
49
104
  }
50
105
 
106
+ func TestCommonRPCCancelContextIdempotent(t *testing.T) {
107
+ var calls atomic.Int32
108
+ rpc := &commonRPC{
109
+ ctx: context.Background(),
110
+ ctxCancel: func() {
111
+ calls.Add(1)
112
+ },
113
+ }
114
+
115
+ rpc.cancelContext()
116
+ rpc.cancelContext()
117
+
118
+ if got := calls.Load(); got != 1 {
119
+ t.Fatalf("expected context cancel once, got %d", got)
120
+ }
121
+ }
122
+
51
123
  func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
52
124
  var rpc *ServerRPC
53
125
  writerClosedOutsideLock := false
@@ -69,6 +141,176 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
69
141
  }
70
142
  }
71
143
 
144
+ func TestServerRPCWaitReturnsAfterLocalInvokeCompletion(t *testing.T) {
145
+ writer := newPacketRecordingWriter()
146
+ streamCtxCh := make(chan context.Context, 1)
147
+ invokeErrCh := make(chan string, 1)
148
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
149
+ if serviceID != "service" || methodID != "method" {
150
+ invokeErrCh <- serviceID + "/" + methodID
151
+ }
152
+ streamCtxCh <- strm.Context()
153
+ return true, nil
154
+ }), writer)
155
+
156
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
157
+ t.Fatalf("handle call start: %v", err)
158
+ }
159
+
160
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
161
+ defer waitCancel()
162
+ if err := rpc.Wait(waitCtx); err != nil {
163
+ t.Fatalf("wait: %v", err)
164
+ }
165
+
166
+ select {
167
+ case got := <-invokeErrCh:
168
+ t.Fatalf("unexpected invoke target: %s", got)
169
+ default:
170
+ }
171
+
172
+ var streamCtx context.Context
173
+ select {
174
+ case streamCtx = <-streamCtxCh:
175
+ case <-time.After(time.Second):
176
+ t.Fatal("invoker did not receive stream context")
177
+ }
178
+ if err := streamCtx.Err(); err != nil {
179
+ t.Fatalf("normal invoke completion canceled stream context: %v", err)
180
+ }
181
+
182
+ select {
183
+ case pkt := <-writer.packets:
184
+ callData := pkt.GetCallData()
185
+ if callData == nil || !callData.GetComplete() || callData.GetError() != "" {
186
+ t.Fatalf("expected successful completion packet, got %#v", pkt.GetBody())
187
+ }
188
+ default:
189
+ t.Fatal("expected completion packet")
190
+ }
191
+
192
+ select {
193
+ case <-writer.closed:
194
+ default:
195
+ t.Fatal("expected writer closed")
196
+ }
197
+ }
198
+
199
+ func TestServerRPCRemoteCloseAfterLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
200
+ writer := newPacketRecordingWriter()
201
+ streamCtxCh := make(chan context.Context, 1)
202
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
203
+ streamCtxCh <- strm.Context()
204
+ return true, nil
205
+ }), writer)
206
+
207
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
208
+ t.Fatalf("handle call start: %v", err)
209
+ }
210
+
211
+ var streamCtx context.Context
212
+ select {
213
+ case streamCtx = <-streamCtxCh:
214
+ case <-time.After(time.Second):
215
+ t.Fatal("invoker did not receive stream context")
216
+ }
217
+ select {
218
+ case <-writer.closed:
219
+ case <-time.After(time.Second):
220
+ t.Fatal("writer did not close")
221
+ }
222
+
223
+ rpc.HandleStreamClose(nil)
224
+
225
+ if err := streamCtx.Err(); err != nil {
226
+ t.Fatalf("normal remote close after local completion canceled stream context: %v", err)
227
+ }
228
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
229
+ defer waitCancel()
230
+ if err := rpc.Wait(waitCtx); err != nil {
231
+ t.Fatalf("wait: %v", err)
232
+ }
233
+ }
234
+
235
+ func TestServerRPCRemoteCloseDuringLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
236
+ writer := newBlockingPacketWriter()
237
+ streamCtxCh := make(chan context.Context, 1)
238
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
239
+ streamCtxCh <- strm.Context()
240
+ return true, nil
241
+ }), writer)
242
+
243
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
244
+ t.Fatalf("handle call start: %v", err)
245
+ }
246
+
247
+ var streamCtx context.Context
248
+ select {
249
+ case streamCtx = <-streamCtxCh:
250
+ case <-time.After(time.Second):
251
+ t.Fatal("invoker did not receive stream context")
252
+ }
253
+ select {
254
+ case <-writer.packets:
255
+ case <-time.After(time.Second):
256
+ t.Fatal("terminal packet write did not start")
257
+ }
258
+
259
+ rpc.HandleStreamClose(nil)
260
+
261
+ if err := streamCtx.Err(); err != nil {
262
+ t.Fatalf("normal remote close during local completion canceled stream context: %v", err)
263
+ }
264
+ close(writer.releaseWrite)
265
+ select {
266
+ case <-writer.closed:
267
+ case <-time.After(time.Second):
268
+ t.Fatal("writer did not close")
269
+ }
270
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
271
+ defer waitCancel()
272
+ if err := rpc.Wait(waitCtx); err != nil {
273
+ t.Fatalf("wait: %v", err)
274
+ }
275
+ }
276
+
277
+ func TestServerRPCRemoteErrorDuringLocalCompletionCancelsStreamContext(t *testing.T) {
278
+ writer := newBlockingPacketWriter()
279
+ streamCtxCh := make(chan context.Context, 1)
280
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
281
+ streamCtxCh <- strm.Context()
282
+ return true, nil
283
+ }), writer)
284
+
285
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
286
+ t.Fatalf("handle call start: %v", err)
287
+ }
288
+
289
+ var streamCtx context.Context
290
+ select {
291
+ case streamCtx = <-streamCtxCh:
292
+ case <-time.After(time.Second):
293
+ t.Fatal("invoker did not receive stream context")
294
+ }
295
+ select {
296
+ case <-writer.packets:
297
+ case <-time.After(time.Second):
298
+ t.Fatal("terminal packet write did not start")
299
+ }
300
+
301
+ rpc.HandleStreamClose(context.Canceled)
302
+
303
+ if err := streamCtx.Err(); err != context.Canceled {
304
+ t.Fatalf("remote error during local completion context error = %v want %v", err, context.Canceled)
305
+ }
306
+ close(writer.releaseWrite)
307
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
308
+ defer waitCancel()
309
+ if err := rpc.Wait(waitCtx); err != context.Canceled {
310
+ t.Fatalf("wait = %v want %v", err, context.Canceled)
311
+ }
312
+ }
313
+
72
314
  func TestClientRPCCloseAfterRemoteCompleteClosesWriter(t *testing.T) {
73
315
  writer := &closeCountingPacketWriter{}
74
316
  rpc := NewClientRPC(context.Background(), "service", "method")
package/srpc/rwc-conn.go CHANGED
@@ -7,6 +7,8 @@ import (
7
7
  "os"
8
8
  "sync"
9
9
  "time"
10
+
11
+ "github.com/aperturerobotics/starpc/internal/contextutil"
10
12
  )
11
13
 
12
14
  // connPktSize is the size of the buffers to use for packets for the RwcConn.
@@ -68,7 +70,7 @@ func NewRwcConn(
68
70
  laddr, raddr net.Addr,
69
71
  bufferPacketN int,
70
72
  ) *RwcConn {
71
- ctx, ctxCancel := context.WithCancel(ctx)
73
+ ctx, ctxCancel := contextutil.WithCancel(ctx)
72
74
  if bufferPacketN <= 0 {
73
75
  bufferPacketN = 10
74
76
  }
@@ -91,13 +91,13 @@ func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
91
91
  // invokeRPC invokes the RPC after CallStart is received.
92
92
  func (r *ServerRPC) invokeRPC(serviceID, methodID string) {
93
93
  // on the server side, the writer is closed by invokeRPC.
94
- strm := NewMsgStream(r.ctx, r, r.ctxCancel)
94
+ strm := NewMsgStream(r.ctx, r, r.cancelContext)
95
95
  ok, err := r.invoker.InvokeMethod(serviceID, methodID, strm)
96
96
  if err == nil && !ok {
97
97
  err = ErrUnimplemented
98
98
  }
99
+ r.beginLocalCompletion()
99
100
  outPkt := NewCallDataPacket(nil, false, true, err)
100
101
  _ = r.writer.WritePacket(outPkt)
101
- _ = r.writer.Close()
102
- r.ctxCancel()
102
+ r.finishLocalCompletion()
103
103
  }
package/srpc/server.go CHANGED
@@ -3,6 +3,8 @@ package srpc
3
3
  import (
4
4
  "context"
5
5
  "io"
6
+
7
+ "github.com/aperturerobotics/starpc/internal/contextutil"
6
8
  )
7
9
 
8
10
  // Server handles incoming RPC streams with a mux.
@@ -26,7 +28,7 @@ func (s *Server) GetInvoker() Invoker {
26
28
  // HandleStream handles an incoming stream and runs the read loop.
27
29
  // Uses length-prefixed packets.
28
30
  func (s *Server) HandleStream(ctx context.Context, rwc io.ReadWriteCloser) {
29
- subCtx, subCtxCancel := context.WithCancel(ctx)
31
+ subCtx, subCtxCancel := contextutil.WithCancel(ctx)
30
32
  defer subCtxCancel()
31
33
  prw := NewPacketReadWriter(rwc)
32
34
  serverRPC := NewServerRPC(subCtx, s.invoker, prw)
@@ -4,6 +4,8 @@ import (
4
4
  "context"
5
5
  "io"
6
6
  "sync"
7
+
8
+ "github.com/aperturerobotics/starpc/internal/contextutil"
7
9
  )
8
10
 
9
11
  // pipeStream implements an in-memory stream.
@@ -22,9 +24,9 @@ type pipeStream struct {
22
24
  // NewPipeStream constructs a new in-memory stream.
23
25
  func NewPipeStream(ctx context.Context) (Stream, Stream) {
24
26
  s1 := &pipeStream{dataCh: make(chan []byte, 5)}
25
- s1.ctx, s1.ctxCancel = context.WithCancel(ctx)
27
+ s1.ctx, s1.ctxCancel = contextutil.WithCancel(ctx)
26
28
  s2 := &pipeStream{other: s1, dataCh: make(chan []byte, 5)}
27
- s2.ctx, s2.ctxCancel = context.WithCancel(ctx)
29
+ s2.ctx, s2.ctxCancel = contextutil.WithCancel(ctx)
28
30
  s1.other = s2
29
31
  return s1, s2
30
32
  }