starpc 0.49.14 → 0.49.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,256 @@
1
+ import { describe, expect, it } from 'vitest'
2
+
3
+ import { Client } from './client.js'
4
+ import { CommonRPC } from './common-rpc.js'
5
+ import { Packet } from './rpcproto.pb.js'
6
+
7
describe('CommonRPC', () => {
  // startBlockedCallDataWriter begins writing a 32-chunk call-data source to
  // rpc and waits until the writer has pulled exactly two chunks and has not
  // requested the third, because nothing is reading the outbound packets yet.
  // It returns the shared yield counter and the writer's completion promise.
  async function startBlockedCallDataWriter(rpc: CommonRPC) {
    const yielded = { count: 0 }
    const secondYield = deferred()
    const thirdYield = deferred()

    const writeDone = rpc.writeCallDataFromSource(
      chunkSource(
        yielded,
        new Map([
          [1, secondYield],
          [2, thirdYield],
        ]),
      ),
    )

    await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
    expect(yielded.count).toBe(2)
    await expectPending(thirdYield.promise, 'third call-data chunk')
    expect(yielded.count).toBe(2)

    return { yielded, writeDone }
  }

  it('backpressures call-data sources until outbound packets drain', async () => {
    const rpc = new CommonRPC()
    const { yielded, writeDone } = await startBlockedCallDataWriter(rpc)

    // Drain the outbound stream, recording the first byte of every call-data
    // chunk until the completion packet arrives.
    const received = new Array<number>()
    const readComplete = (async () => {
      for await (const packet of rpc.source) {
        const body = packet.body
        if (body?.case !== 'callData') {
          continue
        }
        if (body.value.complete) {
          return
        }
        if (!body.value.data) {
          throw new Error('call data packet missing data')
        }
        received.push(body.value.data[0])
      }
      throw new Error('outbound call data ended before completion')
    })()

    await promiseWithTimeout(
      Promise.all([writeDone, readComplete]),
      'backpressured call data drain',
    )
    await rpc.close()

    expect(yielded.count).toBe(32)
    expect(received).toHaveLength(32)
    // Every chunk must arrive exactly once and in source order.
    expect(received).toEqual(Array.from({ length: 32 }, (_, i) => i))
  })

  it('wakes call-data writers when the rpc closes while waiting for drain', async () => {
    const rpc = new CommonRPC()
    const { yielded, writeDone } = await startBlockedCallDataWriter(rpc)

    await rpc.close()
    await expect(
      promiseWithTimeout(writeDone, 'call data writer close'),
    ).resolves.toBeUndefined()
    expect(yielded.count).toBeLessThan(32)
  })

  it('wakes call-data writers when an error closes the rpc while waiting for drain', async () => {
    const rpc = new CommonRPC()
    const { yielded, writeDone } = await startBlockedCallDataWriter(rpc)

    await rpc.close(new Error('boom'))
    await expect(
      promiseWithTimeout(writeDone, 'call data writer error close'),
    ).resolves.toBeUndefined()
    expect(yielded.count).toBeLessThan(32)
  })

  it('does not backpressure call-cancel packets behind queued call data', async () => {
    const rpc = new CommonRPC()
    const { writeDone } = await startBlockedCallDataWriter(rpc)

    // The cancel write must settle while call data is still blocked on
    // drain, i.e. before close() gets a chance to wake waiting writers.
    await expect(
      promiseWithTimeout(rpc.writeCallCancel(), 'call cancel write'),
    ).resolves.toBeUndefined()

    await rpc.close(new Error('abort'))
    await expect(
      promiseWithTimeout(writeDone, 'blocked call data close'),
    ).resolves.toBeUndefined()
  })

  it('backpressures client-stream sources through the Client request path', async () => {
    const yielded = { count: 0 }
    const firstYield = deferred()
    const secondYield = deferred()
    const sinkConsumed = { count: 0 }
    // sinkGate holds the transport sink closed; responseGate holds the reply.
    const sinkGate = deferred()
    const responseGate = deferred()
    const response = new Uint8Array([7])
    const client = new Client(async () => ({
      source: (async function* () {
        await responseGate.promise
        yield Packet.toBinary({
          body: {
            case: 'callData',
            value: {
              data: response,
              dataIsZero: false,
              complete: false,
              error: '',
            },
          },
        })
      })(),
      sink: async (source) => {
        await sinkGate.promise
        for await (const _packet of source) {
          sinkConsumed.count++
        }
      },
    }))

    const request = client.clientStreamingRequest(
      'test.Service',
      'Upload',
      chunkSource(
        yielded,
        new Map([
          [0, firstYield],
          [1, secondYield],
        ]),
      ),
    )

    await promiseWithTimeout(firstYield.promise, 'first upload chunk')
    expect(yielded.count).toBe(1)
    await expectPending(secondYield.promise, 'second upload chunk')
    expect(yielded.count).toBe(1)
    expect(sinkConsumed.count).toBe(0)
    await expectPending(request, 'client-stream request')

    sinkGate.resolve()
    responseGate.resolve()
    await expect(
      promiseWithTimeout(request, 'client-stream response'),
    ).resolves.toEqual(response)
  })
})
203
+
204
// chunkSource yields `total` one-byte chunks (32 by default); chunk i holds
// the byte i (Uint8Array wraps values past 255 modulo 256). Before yielding
// chunk i it increments yielded.count and resolves the deferred registered
// for i in yieldMarks, if any, so tests can observe how far a consumer pulled.
async function* chunkSource(
  yielded: { count: number },
  yieldMarks?: Map<number, Deferred>,
  total = 32,
) {
  for (let i = 0; i < total; i++) {
    yielded.count++
    yieldMarks?.get(i)?.resolve()
    yield new Uint8Array([i])
  }
}
214
+
215
// promiseWithTimeout settles like promise, unless promise is still pending
// after timeoutMs milliseconds (500 by default), in which case it rejects with
// "timed out waiting for <label>". The timer is cleared as soon as the race
// settles so it does not linger and keep the event loop alive.
function promiseWithTimeout<T>(
  promise: Promise<T>,
  label: string,
  timeoutMs = 500,
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined
  const timeout = new Promise<T>((_, reject) => {
    timer = setTimeout(
      () => reject(new Error(`timed out waiting for ${label}`)),
      timeoutMs,
    )
  })
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer))
}
223
+
224
// expectPending asserts that promise is still unsettled after waitMs
// milliseconds (50 by default). If promise resolves first the assertion
// fails; if it rejects, that rejection fails the assertion. The timer is
// cleared afterwards so it does not outlive the check.
async function expectPending<T>(
  promise: Promise<T>,
  label: string,
  waitMs = 50,
) {
  let timer: ReturnType<typeof setTimeout> | undefined
  const pending = new Promise<'pending'>((resolve) => {
    timer = setTimeout(() => resolve('pending'), waitMs)
  })
  try {
    await expect(
      Promise.race([promise.then(() => 'settled'), pending]),
      `${label} should not be requested while outbound packets are blocked`,
    ).resolves.toBe('pending')
  } finally {
    clearTimeout(timer)
  }
}
235
+
236
type Deferred = ReturnType<typeof deferred>

// deferred creates a void promise along with resolve/reject functions that
// settle it from outside the Promise executor.
function deferred() {
  let settle: (() => void) | undefined
  let fail: ((reason?: unknown) => void) | undefined
  const promise = new Promise<void>((res, rej) => {
    settle = res
    fail = rej
  })
  return {
    promise,
    resolve() {
      settle?.()
    },
    reject(reason?: unknown) {
      fail?.(reason)
    },
  }
}
@@ -1,11 +1,13 @@
1
1
  import type { Sink, Source } from 'it-stream-types'
2
- import { pushable } from 'it-pushable'
2
+ import { pushable, type Pushable } from 'it-pushable'
3
3
  import { CompleteMessage } from '@aptre/protobuf-es-lite'
4
4
 
5
5
  import type { CallData, CallStart } from './rpcproto.pb.js'
6
6
  import { Packet } from './rpcproto.pb.js'
7
7
  import { ERR_RPC_ABORT, RemoteRPCError } from './errors.js'
8
8
 
9
// maxBufferedOutgoingPackets is the most packets writePacket leaves queued on
// the outbound stream without waiting; above this it waits for the queue to
// drain (unless the write opts out with waitForDrain: false).
const maxBufferedOutgoingPackets = 1
10
+
9
11
  // CommonRPC is common logic between server and client RPCs.
10
12
  export class CommonRPC {
11
13
  // sink is the data sink for incoming messages.
@@ -16,7 +18,7 @@ export class CommonRPC {
16
18
  public readonly rpcDataSource: AsyncIterable<Uint8Array>
17
19
 
18
20
  // _source is the outbound packet queue (object mode): writePacket pushes
  // packets onto it and checks its readableLength for backpressure.
  private readonly _source: Pushable<Packet> = pushable<Packet>({
    objectMode: true,
  })
22
24
 
@@ -32,6 +34,8 @@ export class CommonRPC {
32
34
 
33
35
  // closed indicates this rpc has been closed already: true for a clean
  // close, or the Error it was closed with.
  private closed?: true | Error
  // writeDrainAbort wakes writers waiting for outbound stream drain on close;
  // its signal is passed to _source.onEmpty in writePacket.
  private readonly writeDrainAbort = new AbortController()
35
39
 
36
40
  constructor() {
37
41
  this.sink = this._createSink()
@@ -49,6 +53,16 @@ export class CommonRPC {
49
53
  data?: Uint8Array,
50
54
  complete?: boolean,
51
55
  error?: string,
56
+ ) {
57
+ await this.writeCallDataPacket(data, complete, error)
58
+ }
59
+
60
+ // writeCallDataPacket writes a call-data packet with optional drain control.
61
+ private async writeCallDataPacket(
62
+ data?: Uint8Array,
63
+ complete?: boolean,
64
+ error?: string,
65
+ writeOptions?: WritePacketOptions,
52
66
  ) {
53
67
  const callData: CompleteMessage<CallData> = {
54
68
  data: data || new Uint8Array(0),
@@ -56,22 +70,28 @@ export class CommonRPC {
56
70
  complete: complete || false,
57
71
  error: error || '',
58
72
  }
59
- await this.writePacket({
60
- body: {
61
- case: 'callData',
62
- value: callData,
73
+ await this.writePacket(
74
+ {
75
+ body: {
76
+ case: 'callData',
77
+ value: callData,
78
+ },
63
79
  },
64
- })
80
+ writeOptions,
81
+ )
65
82
  }
66
83
 
67
84
  // writeCallCancel writes the call cancel packet.
68
85
  public async writeCallCancel() {
69
- await this.writePacket({
70
- body: {
71
- case: 'callCancel',
72
- value: true,
86
+ await this.writePacket(
87
+ {
88
+ body: {
89
+ case: 'callCancel',
90
+ value: true,
91
+ },
73
92
  },
74
- })
93
+ { waitForDrain: false },
94
+ )
75
95
  }
76
96
 
77
97
  // writeCallDataFromSource writes all call data from the iterable.
@@ -87,8 +107,25 @@ export class CommonRPC {
87
107
  }
88
108
 
89
109
  // writePacket writes a packet to the stream.
90
- protected async writePacket(packet: Packet) {
110
+ protected async writePacket(packet: Packet, options?: WritePacketOptions) {
111
+ if (this.closed && !options?.allowClosed) {
112
+ throw new Error(ERR_RPC_ABORT)
113
+ }
91
114
  this._source.push(packet)
115
+ if (
116
+ options?.waitForDrain === false ||
117
+ this._source.readableLength <= maxBufferedOutgoingPackets
118
+ ) {
119
+ return
120
+ }
121
+ try {
122
+ await this._source.onEmpty({ signal: this.writeDrainAbort.signal })
123
+ } catch (err) {
124
+ if (this.closed) {
125
+ throw new Error(ERR_RPC_ABORT, { cause: err })
126
+ }
127
+ throw err
128
+ }
92
129
  }
93
130
 
94
131
  // handleMessage handles an incoming encoded Packet.
@@ -179,8 +216,12 @@ export class CommonRPC {
179
216
  this.closed = err ?? true
180
217
  // note: this does nothing if _source is already ended.
181
218
  if (err && err.message) {
182
- await this.writeCallData(undefined, true, err.message)
219
+ await this.writeCallDataPacket(undefined, true, err.message, {
220
+ allowClosed: true,
221
+ waitForDrain: false,
222
+ })
183
223
  }
224
+ this.writeDrainAbort.abort()
184
225
  this._source.end()
185
226
  this._rpcDataSource.end(err)
186
227
  }
@@ -206,3 +247,8 @@ export class CommonRPC {
206
247
  }
207
248
  }
208
249
  }
250
+
251
// WritePacketOptions adjusts a single writePacket call.
interface WritePacketOptions {
  // allowClosed permits the write even after the rpc has been closed (used
  // to emit the terminal error packet while closing).
  allowClosed?: boolean
  // waitForDrain, when explicitly false, skips waiting for the outbound
  // queue to drain after the packet is pushed.
  waitForDrain?: boolean
}
@@ -4,8 +4,10 @@ import (
4
4
  "bytes"
5
5
  "context"
6
6
  "io"
7
+ "sync"
7
8
  "sync/atomic"
8
9
  "testing"
10
+ "time"
9
11
  )
10
12
 
11
13
  type closeCountingPacketWriter struct {
@@ -16,6 +18,19 @@ type closeCallbackPacketWriter struct {
16
18
  closeFn func()
17
19
  }
18
20
 
21
// packetRecordingWriter is a test packet writer that forwards every written
// packet to packets and closes closed when Close is called.
type packetRecordingWriter struct {
	// packets receives each packet passed to WritePacket.
	packets chan *Packet
	// closed is closed on the first Close call.
	closed chan struct{}
	// once guards closing closed.
	once sync.Once
}
26
+
27
// blockingPacketWriter is a test packet writer whose WritePacket publishes
// the packet on packets and then blocks until releaseWrite is closed, letting
// a test hold the rpc in the middle of a write.
type blockingPacketWriter struct {
	// packets receives each packet as WritePacket starts.
	packets chan *Packet
	// releaseWrite unblocks WritePacket calls once it is closed.
	releaseWrite chan struct{}
	// closed is closed on the first Close call.
	closed chan struct{}
	// once guards closing closed.
	once sync.Once
}
33
+
19
34
// WritePacket discards the packet and always reports success.
func (w *closeCountingPacketWriter) WritePacket(*Packet) error {
	return nil
}
@@ -36,6 +51,46 @@ func (w *closeCallbackPacketWriter) Close() error {
36
51
  return nil
37
52
  }
38
53
 
54
+ func newPacketRecordingWriter() *packetRecordingWriter {
55
+ return &packetRecordingWriter{
56
+ packets: make(chan *Packet, 1),
57
+ closed: make(chan struct{}),
58
+ }
59
+ }
60
+
61
+ func (w *packetRecordingWriter) WritePacket(pkt *Packet) error {
62
+ w.packets <- pkt
63
+ return nil
64
+ }
65
+
66
+ func (w *packetRecordingWriter) Close() error {
67
+ w.once.Do(func() {
68
+ close(w.closed)
69
+ })
70
+ return nil
71
+ }
72
+
73
+ func newBlockingPacketWriter() *blockingPacketWriter {
74
+ return &blockingPacketWriter{
75
+ packets: make(chan *Packet, 1),
76
+ releaseWrite: make(chan struct{}),
77
+ closed: make(chan struct{}),
78
+ }
79
+ }
80
+
81
+ func (w *blockingPacketWriter) WritePacket(pkt *Packet) error {
82
+ w.packets <- pkt
83
+ <-w.releaseWrite
84
+ return nil
85
+ }
86
+
87
+ func (w *blockingPacketWriter) Close() error {
88
+ w.once.Do(func() {
89
+ close(w.closed)
90
+ })
91
+ return nil
92
+ }
93
+
39
94
  func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
40
95
  writer := &closeCountingPacketWriter{}
41
96
  rpc := NewServerRPC(context.Background(), InvokerFunc(nil), writer)
@@ -48,6 +103,23 @@ func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
48
103
  }
49
104
  }
50
105
 
106
+ func TestCommonRPCCancelContextIdempotent(t *testing.T) {
107
+ var calls atomic.Int32
108
+ rpc := &commonRPC{
109
+ ctx: context.Background(),
110
+ ctxCancel: func() {
111
+ calls.Add(1)
112
+ },
113
+ }
114
+
115
+ rpc.cancelContext()
116
+ rpc.cancelContext()
117
+
118
+ if got := calls.Load(); got != 1 {
119
+ t.Fatalf("expected context cancel once, got %d", got)
120
+ }
121
+ }
122
+
51
123
  func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
52
124
  var rpc *ServerRPC
53
125
  writerClosedOutsideLock := false
@@ -69,6 +141,214 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
69
141
  }
70
142
  }
71
143
 
144
+ func TestServerRPCWaitReturnsAfterLocalInvokeCompletion(t *testing.T) {
145
+ writer := newPacketRecordingWriter()
146
+ streamCtxCh := make(chan context.Context, 1)
147
+ invokeErrCh := make(chan string, 1)
148
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
149
+ if serviceID != "service" || methodID != "method" {
150
+ invokeErrCh <- serviceID + "/" + methodID
151
+ }
152
+ streamCtxCh <- strm.Context()
153
+ return true, nil
154
+ }), writer)
155
+
156
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
157
+ t.Fatalf("handle call start: %v", err)
158
+ }
159
+
160
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
161
+ defer waitCancel()
162
+ if err := rpc.Wait(waitCtx); err != nil {
163
+ t.Fatalf("wait: %v", err)
164
+ }
165
+
166
+ select {
167
+ case got := <-invokeErrCh:
168
+ t.Fatalf("unexpected invoke target: %s", got)
169
+ default:
170
+ }
171
+
172
+ var streamCtx context.Context
173
+ select {
174
+ case streamCtx = <-streamCtxCh:
175
+ case <-time.After(time.Second):
176
+ t.Fatal("invoker did not receive stream context")
177
+ }
178
+ if err := streamCtx.Err(); err != nil {
179
+ t.Fatalf("normal invoke completion canceled stream context: %v", err)
180
+ }
181
+
182
+ select {
183
+ case pkt := <-writer.packets:
184
+ callData := pkt.GetCallData()
185
+ if callData == nil || !callData.GetComplete() || callData.GetError() != "" {
186
+ t.Fatalf("expected successful completion packet, got %#v", pkt.GetBody())
187
+ }
188
+ default:
189
+ t.Fatal("expected completion packet")
190
+ }
191
+
192
+ select {
193
+ case <-writer.closed:
194
+ default:
195
+ t.Fatal("expected writer closed")
196
+ }
197
+ }
198
+
199
+ func TestServerRPCRemoteCloseAfterLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
200
+ writer := newPacketRecordingWriter()
201
+ streamCtxCh := make(chan context.Context, 1)
202
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
203
+ streamCtxCh <- strm.Context()
204
+ return true, nil
205
+ }), writer)
206
+
207
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
208
+ t.Fatalf("handle call start: %v", err)
209
+ }
210
+
211
+ var streamCtx context.Context
212
+ select {
213
+ case streamCtx = <-streamCtxCh:
214
+ case <-time.After(time.Second):
215
+ t.Fatal("invoker did not receive stream context")
216
+ }
217
+ select {
218
+ case <-writer.closed:
219
+ case <-time.After(time.Second):
220
+ t.Fatal("writer did not close")
221
+ }
222
+
223
+ rpc.HandleStreamClose(nil)
224
+
225
+ if err := streamCtx.Err(); err != nil {
226
+ t.Fatalf("normal remote close after local completion canceled stream context: %v", err)
227
+ }
228
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
229
+ defer waitCancel()
230
+ if err := rpc.Wait(waitCtx); err != nil {
231
+ t.Fatalf("wait: %v", err)
232
+ }
233
+ }
234
+
235
+ func TestServerRPCRemoteCloseDuringLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
236
+ writer := newBlockingPacketWriter()
237
+ streamCtxCh := make(chan context.Context, 1)
238
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
239
+ streamCtxCh <- strm.Context()
240
+ return true, nil
241
+ }), writer)
242
+
243
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
244
+ t.Fatalf("handle call start: %v", err)
245
+ }
246
+
247
+ var streamCtx context.Context
248
+ select {
249
+ case streamCtx = <-streamCtxCh:
250
+ case <-time.After(time.Second):
251
+ t.Fatal("invoker did not receive stream context")
252
+ }
253
+ select {
254
+ case <-writer.packets:
255
+ case <-time.After(time.Second):
256
+ t.Fatal("terminal packet write did not start")
257
+ }
258
+
259
+ rpc.HandleStreamClose(nil)
260
+
261
+ if err := streamCtx.Err(); err != nil {
262
+ t.Fatalf("normal remote close during local completion canceled stream context: %v", err)
263
+ }
264
+ close(writer.releaseWrite)
265
+ select {
266
+ case <-writer.closed:
267
+ case <-time.After(time.Second):
268
+ t.Fatal("writer did not close")
269
+ }
270
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
271
+ defer waitCancel()
272
+ if err := rpc.Wait(waitCtx); err != nil {
273
+ t.Fatalf("wait: %v", err)
274
+ }
275
+ }
276
+
277
+ func TestServerRPCRemoteErrorDuringLocalCompletionCancelsStreamContext(t *testing.T) {
278
+ writer := newBlockingPacketWriter()
279
+ streamCtxCh := make(chan context.Context, 1)
280
+ rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
281
+ streamCtxCh <- strm.Context()
282
+ return true, nil
283
+ }), writer)
284
+
285
+ if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
286
+ t.Fatalf("handle call start: %v", err)
287
+ }
288
+
289
+ var streamCtx context.Context
290
+ select {
291
+ case streamCtx = <-streamCtxCh:
292
+ case <-time.After(time.Second):
293
+ t.Fatal("invoker did not receive stream context")
294
+ }
295
+ select {
296
+ case <-writer.packets:
297
+ case <-time.After(time.Second):
298
+ t.Fatal("terminal packet write did not start")
299
+ }
300
+
301
+ rpc.HandleStreamClose(context.Canceled)
302
+
303
+ if err := streamCtx.Err(); err != context.Canceled {
304
+ t.Fatalf("remote error during local completion context error = %v want %v", err, context.Canceled)
305
+ }
306
+ close(writer.releaseWrite)
307
+ waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
308
+ defer waitCancel()
309
+ if err := rpc.Wait(waitCtx); err != context.Canceled {
310
+ t.Fatalf("wait = %v want %v", err, context.Canceled)
311
+ }
312
+ }
313
+
314
+ func TestClientRPCCloseAfterRemoteCompleteClosesWriter(t *testing.T) {
315
+ writer := &closeCountingPacketWriter{}
316
+ rpc := NewClientRPC(context.Background(), "service", "method")
317
+ if err := rpc.Start(writer, false, nil); err != nil {
318
+ t.Fatalf("start: %v", err)
319
+ }
320
+ if err := rpc.HandleCallData(NewCallDataPacket([]byte("ok"), false, true, nil).GetCallData()); err != nil {
321
+ t.Fatalf("handle call data: %v", err)
322
+ }
323
+
324
+ rpc.Close()
325
+
326
+ if got := writer.closed.Load(); got != 1 {
327
+ t.Fatalf("expected writer closed once, got %d", got)
328
+ }
329
+ }
330
+
331
+ func TestClientRPCCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
332
+ var rpc *ClientRPC
333
+ writerClosedOutsideLock := false
334
+ writer := &closeCallbackPacketWriter{
335
+ closeFn: func() {
336
+ ok := rpc.bcast.TryHoldLock(func(func(), func() <-chan struct{}) {})
337
+ writerClosedOutsideLock = ok
338
+ },
339
+ }
340
+ rpc = NewClientRPC(context.Background(), "service", "method")
341
+ if err := rpc.Start(writer, false, nil); err != nil {
342
+ t.Fatalf("start: %v", err)
343
+ }
344
+
345
+ rpc.Close()
346
+
347
+ if !writerClosedOutsideLock {
348
+ t.Fatal("expected writer close outside broadcast lock")
349
+ }
350
+ }
351
+
72
352
  func TestCommonRPCReadOneQueuedDoesNotAllocate(t *testing.T) {
73
353
  msg := []byte("message")
74
354
  queue := make([][]byte, 1)