starpc 0.49.10 → 0.49.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/srpc/channel.ts CHANGED
@@ -13,6 +13,8 @@ export interface ChannelStreamMessage<T> {
13
13
  opened?: true
14
14
  // closed indicates the stream is closed.
15
15
  closed?: true
16
+ // teardown indicates the whole channel should be torn down.
17
+ teardown?: true
16
18
  // error indicates the stream has an error.
17
19
  error?: Error
18
20
  // data is any message data.
@@ -82,6 +84,10 @@ export class ChannelStream<T = Uint8Array> implements Duplex<
82
84
  private idleWatchdog?: Watchdog
83
85
  // closed indicates the local channel has been torn down.
84
86
  private closed = false
87
+ // localWriteClosed indicates the local sink finished normally.
88
+ private localWriteClosed = false
89
+ // remoteWriteClosed indicates the remote sink finished normally.
90
+ private remoteWriteClosed = false
85
91
 
86
92
  // isAcked checks if the stream is acknowledged by the remote.
87
93
  public get isAcked() {
@@ -207,12 +213,14 @@ export class ChannelStream<T = Uint8Array> implements Duplex<
207
213
  }
208
214
  if (notifyRemote) {
209
215
  try {
210
- this.postMessage({ closed: true, error })
216
+ this.postMessage({ closed: true, error, teardown: true })
211
217
  } catch {
212
218
  // Ignore close races while tearing down the local channel.
213
219
  }
214
220
  }
215
221
  this.closed = true
222
+ this.localWriteClosed = true
223
+ this.remoteWriteClosed = true
216
224
  // close channels
217
225
  if (this.channel instanceof MessagePort) {
218
226
  this.channel.onmessage = null
@@ -241,6 +249,29 @@ export class ChannelStream<T = Uint8Array> implements Duplex<
241
249
  this._source.end(error)
242
250
  }
243
251
 
252
+ // finishIfBothDirectionsClosed tears down after a clean bidirectional EOF.
253
+ private finishIfBothDirectionsClosed() {
254
+ if (this.localWriteClosed && this.remoteWriteClosed) {
255
+ this.finish(undefined, false)
256
+ }
257
+ }
258
+
259
+ // clearKeepAlive stops local write keep-alives after the write side closes.
260
+ private clearKeepAlive() {
261
+ if (this.keepAlive) {
262
+ this.keepAlive.clear()
263
+ delete this.keepAlive
264
+ }
265
+ }
266
+
267
+ // clearIdleWatchdog stops receive watchdogs after the peer write side closes.
268
+ private clearIdleWatchdog() {
269
+ if (this.idleWatchdog) {
270
+ this.idleWatchdog.clear()
271
+ delete this.idleWatchdog
272
+ }
273
+ }
274
+
244
275
  // close closes the broadcast channels.
245
276
  public close(error?: Error) {
246
277
  this.finish(error, true)
@@ -297,9 +328,15 @@ export class ChannelStream<T = Uint8Array> implements Duplex<
297
328
  for await (const msg of source) {
298
329
  this.postMessage({ data: msg })
299
330
  }
331
+ this.localWriteClosed = true
332
+ this.clearKeepAlive()
300
333
  this.postMessage({ closed: true })
334
+ this.finishIfBothDirectionsClosed()
301
335
  } catch (error) {
302
- this.postMessage({ closed: true, error: error as Error })
336
+ this.finish(
337
+ error instanceof Error ? error : new Error(String(error)),
338
+ true,
339
+ )
303
340
  }
304
341
  }
305
342
  }
@@ -316,16 +353,19 @@ export class ChannelStream<T = Uint8Array> implements Duplex<
316
353
  if (msg.opened) {
317
354
  this.onRemoteOpened()
318
355
  }
319
- const { data, closed, error: err } = msg
320
- if (data) {
356
+ const { data, closed, error: err, teardown } = msg
357
+ if (data !== undefined) {
321
358
  this._source.push(data)
322
359
  }
323
- if (err) {
360
+ if (err || teardown) {
324
361
  this.finish(err, false)
325
362
  return
326
363
  }
327
364
  if (closed) {
328
- this.finish(undefined, false)
365
+ this.remoteWriteClosed = true
366
+ this.clearIdleWatchdog()
367
+ this._source.end()
368
+ this.finishIfBothDirectionsClosed()
329
369
  }
330
370
  }
331
371
  }
@@ -38,22 +38,22 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
38
38
 
39
39
  var firstMsgEmpty bool
40
40
  var err error
41
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
42
- r.writer = writer
41
+ locked := r.bcast.Lock()
42
+ r.writer = writer
43
43
 
44
- if writeFirstMsg {
45
- firstMsgEmpty = len(firstMsg) == 0
46
- }
44
+ if writeFirstMsg {
45
+ firstMsgEmpty = len(firstMsg) == 0
46
+ }
47
47
 
48
- pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
49
- err = writer.WritePacket(pkt)
50
- if err != nil {
51
- r.ctxCancel()
52
- _ = writer.Close()
53
- }
48
+ pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
49
+ err = writer.WritePacket(pkt)
50
+ if err != nil {
51
+ r.ctxCancel()
52
+ _ = writer.Close()
53
+ }
54
54
 
55
- broadcast()
56
- })
55
+ locked.Broadcast()
56
+ locked.Unlock()
57
57
 
58
58
  return err
59
59
  }
@@ -69,14 +69,14 @@ func (r *ClientRPC) HandlePacketData(data []byte) error {
69
69
 
70
70
  // HandleStreamClose handles the stream closing optionally w/ an error.
71
71
  func (r *ClientRPC) HandleStreamClose(closeErr error) {
72
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
73
- if closeErr != nil && r.remoteErr == nil {
74
- r.remoteErr = closeErr
75
- }
76
- r.dataClosed = true
77
- r.ctxCancel()
78
- broadcast()
79
- })
72
+ locked := r.bcast.Lock()
73
+ if closeErr != nil && r.remoteErr == nil {
74
+ r.remoteErr = closeErr
75
+ }
76
+ r.dataClosed = true
77
+ r.ctxCancel()
78
+ locked.Broadcast()
79
+ locked.Unlock()
80
80
  }
81
81
 
82
82
  // HandlePacket handles an incoming parsed message packet.
@@ -108,11 +108,11 @@ func (r *ClientRPC) HandleCallStart(pkt *CallStart) error {
108
108
 
109
109
  // Close releases any resources held by the ClientRPC.
110
110
  func (r *ClientRPC) Close() {
111
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
112
- // call did not start yet if writer is nil.
113
- if r.writer != nil {
114
- _ = r.WriteCallCancel()
115
- r.closeLocked(broadcast)
116
- }
117
- })
111
+ locked := r.bcast.Lock()
112
+ // call did not start yet if writer is nil.
113
+ if r.writer != nil {
114
+ _ = r.WriteCallCancel()
115
+ r.closeLocked(&locked)
116
+ }
117
+ locked.Unlock()
118
118
  }
@@ -52,10 +52,12 @@ func (c *commonRPC) Wait(ctx context.Context) error {
52
52
  var err error
53
53
  var waitCh <-chan struct{}
54
54
  var rpcCtx context.Context
55
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
56
- rpcCtx, err = c.ctx, c.remoteErr
57
- waitCh = getWaitCh()
58
- })
55
+ locked := c.bcast.Lock()
56
+ rpcCtx, err = c.ctx, c.remoteErr
57
+ if err == nil && rpcCtx.Err() == nil {
58
+ waitCh = locked.WaitCh()
59
+ }
60
+ locked.Unlock()
59
61
 
60
62
  if err != nil {
61
63
  return err
@@ -78,43 +80,37 @@ func (c *commonRPC) Wait(ctx context.Context) error {
78
80
  //
79
81
  // returns io.EOF if the stream ended without a packet.
80
82
  func (c *commonRPC) ReadOne() ([]byte, error) {
81
- var hasMsg bool
82
- var msg []byte
83
- var err error
84
83
  var ctxDone bool
85
84
  for {
86
85
  var waitCh <-chan struct{}
87
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
88
- if ctxDone && !c.dataClosed {
89
- // context must have been canceled locally
90
- c.closeLocked(broadcast)
91
- err = context.Canceled
92
- return
93
- }
94
-
95
- if len(c.dataQueue) != 0 {
96
- msg = c.dataQueue[0]
97
- hasMsg = true
98
- c.dataQueue[0] = nil
99
- c.dataQueue = c.dataQueue[1:]
100
- } else if c.dataClosed || c.remoteErr != nil {
101
- err = c.remoteErr
102
- if err == nil {
103
- err = io.EOF
104
- }
105
- }
106
-
107
- waitCh = getWaitCh()
108
- })
86
+ locked := c.bcast.Lock()
87
+ if ctxDone && !c.dataClosed {
88
+ // context must have been canceled locally
89
+ c.closeLocked(&locked)
90
+ locked.Unlock()
91
+ return nil, context.Canceled
92
+ }
109
93
 
110
- if hasMsg {
94
+ if len(c.dataQueue) != 0 {
95
+ msg := c.dataQueue[0]
96
+ c.dataQueue[0] = nil
97
+ c.dataQueue = c.dataQueue[1:]
98
+ locked.Unlock()
111
99
  return msg, nil
112
100
  }
113
101
 
114
- if err != nil {
102
+ if c.dataClosed || c.remoteErr != nil {
103
+ err := c.remoteErr
104
+ if err == nil {
105
+ err = io.EOF
106
+ }
107
+ locked.Unlock()
115
108
  return nil, err
116
109
  }
117
110
 
111
+ waitCh = locked.WaitCh()
112
+ locked.Unlock()
113
+
118
114
  select {
119
115
  case <-c.ctx.Done():
120
116
  ctxDone = true
@@ -147,18 +143,19 @@ func (c *commonRPC) WriteCallData(data []byte, dataIsZero, complete bool, err er
147
143
  // HandleStreamClose handles the incoming stream closing w/ optional error.
148
144
  func (c *commonRPC) HandleStreamClose(closeErr error) {
149
145
  var writer PacketWriter
150
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
151
- if c.dataClosed {
152
- return
153
- }
154
- if closeErr != nil && c.remoteErr == nil {
155
- c.remoteErr = closeErr
156
- }
157
- c.dataClosed = true
158
- c.ctxCancel()
159
- writer = c.writer
160
- broadcast()
161
- })
146
+ locked := c.bcast.Lock()
147
+ if c.dataClosed {
148
+ locked.Unlock()
149
+ return
150
+ }
151
+ if closeErr != nil && c.remoteErr == nil {
152
+ c.remoteErr = closeErr
153
+ }
154
+ c.dataClosed = true
155
+ c.ctxCancel()
156
+ writer = c.writer
157
+ locked.Broadcast()
158
+ locked.Unlock()
162
159
  if writer != nil {
163
160
  _ = writer.Close()
164
161
  }
@@ -173,34 +170,33 @@ func (c *commonRPC) HandleCallCancel() error {
173
170
  // HandleCallData handles the call data packet.
174
171
  func (c *commonRPC) HandleCallData(pkt *CallData) error {
175
172
  var err error
176
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
177
- if c.dataClosed {
178
- // If the packet is just indicating the call is complete, ignore it.
179
- if pkt.GetComplete() {
180
- return
181
- }
182
-
173
+ locked := c.bcast.Lock()
174
+ if c.dataClosed {
175
+ // If the packet is just indicating the call is complete, ignore it.
176
+ if !pkt.GetComplete() {
183
177
  // Otherwise, return ErrCompleted (unexpected packet).
184
178
  err = ErrCompleted
185
- return
186
179
  }
180
+ locked.Unlock()
181
+ return err
182
+ }
187
183
 
188
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
189
- c.dataQueue = append(c.dataQueue, data)
190
- }
184
+ if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
185
+ c.dataQueue = append(c.dataQueue, data)
186
+ }
191
187
 
192
- complete := pkt.GetComplete()
193
- if err := pkt.GetError(); len(err) != 0 {
194
- complete = true
195
- c.remoteErr = errors.New(err)
196
- }
188
+ complete := pkt.GetComplete()
189
+ if pktErr := pkt.GetError(); len(pktErr) != 0 {
190
+ complete = true
191
+ c.remoteErr = errors.New(pktErr)
192
+ }
197
193
 
198
- if complete {
199
- c.dataClosed = true
200
- }
194
+ if complete {
195
+ c.dataClosed = true
196
+ }
201
197
 
202
- broadcast()
203
- })
198
+ locked.Broadcast()
199
+ locked.Unlock()
204
200
 
205
201
  return err
206
202
  }
@@ -216,7 +212,7 @@ func (c *commonRPC) WriteCallCancel() error {
216
212
  }
217
213
 
218
214
  // closeLocked releases resources held by the RPC.
219
- func (c *commonRPC) closeLocked(broadcast func()) {
215
+ func (c *commonRPC) closeLocked(locked *broadcast.Locked) {
220
216
  if c.dataClosed {
221
217
  return
222
218
  }
@@ -226,6 +222,6 @@ func (c *commonRPC) closeLocked(broadcast func()) {
226
222
  c.remoteErr = context.Canceled
227
223
  }
228
224
  _ = c.writer.Close()
229
- broadcast()
225
+ locked.Broadcast()
230
226
  c.ctxCancel()
231
227
  }
@@ -1,6 +1,7 @@
1
1
  package srpc
2
2
 
3
3
  import (
4
+ "bytes"
4
5
  "context"
5
6
  "io"
6
7
  "sync/atomic"
@@ -52,7 +53,11 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
52
53
  writerClosedOutsideLock := false
53
54
  writer := &closeCallbackPacketWriter{
54
55
  closeFn: func() {
55
- writerClosedOutsideLock = rpc.bcast.TryHoldLock(func(func(), func() <-chan struct{}) {})
56
+ locked, ok := rpc.bcast.TryLock()
57
+ if ok {
58
+ locked.Unlock()
59
+ }
60
+ writerClosedOutsideLock = ok
56
61
  },
57
62
  }
58
63
  rpc = NewServerRPC(context.Background(), InvokerFunc(nil), writer)
@@ -63,3 +68,26 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
63
68
  t.Fatal("expected writer close outside broadcast lock")
64
69
  }
65
70
  }
71
+
72
+ func TestCommonRPCReadOneQueuedDoesNotAllocate(t *testing.T) {
73
+ msg := []byte("message")
74
+ queue := make([][]byte, 1)
75
+ rpc := NewServerRPC(context.Background(), InvokerFunc(nil), &closeCountingPacketWriter{})
76
+
77
+ allocs := testing.AllocsPerRun(1000, func() {
78
+ queue[0] = msg
79
+ rpc.dataQueue = queue
80
+
81
+ got, err := rpc.ReadOne()
82
+ if err != nil {
83
+ t.Fatalf("read one: %v", err)
84
+ }
85
+ if !bytes.Equal(got, msg) {
86
+ t.Fatalf("expected %q, got %q", msg, got)
87
+ }
88
+ })
89
+
90
+ if allocs != 0 {
91
+ t.Fatalf("expected queued ReadOne to avoid allocations, got %f", allocs)
92
+ }
93
+ }
@@ -59,29 +59,31 @@ func (r *ServerRPC) HandlePacket(msg *Packet) error {
59
59
  func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
60
60
  var err error
61
61
 
62
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
63
- // process start: method and service
64
- if r.method != "" || r.service != "" {
65
- err = errors.New("call start must be sent only once")
66
- return
67
- }
68
- if r.dataClosed {
69
- err = ErrCompleted
70
- return
71
- }
62
+ locked := r.bcast.Lock()
63
+ // process start: method and service
64
+ if r.method != "" || r.service != "" {
65
+ err = errors.New("call start must be sent only once")
66
+ locked.Unlock()
67
+ return err
68
+ }
69
+ if r.dataClosed {
70
+ err = ErrCompleted
71
+ locked.Unlock()
72
+ return err
73
+ }
72
74
 
73
- service, method := pkt.GetRpcService(), pkt.GetRpcMethod()
74
- r.service, r.method = service, method
75
+ service, method := pkt.GetRpcService(), pkt.GetRpcMethod()
76
+ r.service, r.method = service, method
75
77
 
76
- // process first data packet, if included
77
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
78
- r.dataQueue = append(r.dataQueue, data)
79
- }
78
+ // process first data packet, if included
79
+ if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
80
+ r.dataQueue = append(r.dataQueue, data)
81
+ }
80
82
 
81
- // invoke the rpc
82
- broadcast()
83
- go r.invokeRPC(service, method)
84
- })
83
+ // invoke the rpc
84
+ locked.Broadcast()
85
+ go r.invokeRPC(service, method)
86
+ locked.Unlock()
85
87
 
86
88
  return err
87
89
  }
@@ -9,6 +9,7 @@ import {
9
9
  ChannelStream,
10
10
  combineUint8ArrayListTransform,
11
11
  ChannelStreamOpts,
12
+ Packet,
12
13
  } from '../srpc/index.js'
13
14
  import {
14
15
  EchoerDefinition,
@@ -77,6 +78,68 @@ describe('srpc server', () => {
77
78
  await runRpcStreamTest(client)
78
79
  })
79
80
 
81
+ it('keeps detached server-streaming responses open after request source completes', async () => {
82
+ const mux = createMux()
83
+ const response = new TextEncoder().encode('delayed init')
84
+ mux.registerLookupMethod(async (service, method) => {
85
+ if (service !== 'test.ResourceService' || method !== 'ResourceClient') {
86
+ return null
87
+ }
88
+ return async (_dataSource, dataSink) => {
89
+ await dataSink(
90
+ (async function* () {
91
+ await new Promise((resolve) => setTimeout(resolve, 10))
92
+ yield response
93
+ })(),
94
+ )
95
+ }
96
+ })
97
+
98
+ const server = new Server(mux.lookupMethod)
99
+ const firstResponse = new Promise<Packet>((resolve, reject) => {
100
+ server.handlePacketStream({
101
+ source: (async function* () {
102
+ yield Packet.toBinary({
103
+ body: {
104
+ case: 'callStart',
105
+ value: {
106
+ rpcService: 'test.ResourceService',
107
+ rpcMethod: 'ResourceClient',
108
+ data: new Uint8Array(0),
109
+ dataIsZero: true,
110
+ },
111
+ },
112
+ })
113
+ })(),
114
+ sink: async (source) => {
115
+ try {
116
+ for await (const packetData of source) {
117
+ const packet = Packet.fromBinary(packetData)
118
+ if (packet.body?.case === 'callData') {
119
+ resolve(packet)
120
+ return
121
+ }
122
+ }
123
+ reject(new Error('server response stream ended before call data'))
124
+ } catch (err) {
125
+ reject(err as Error)
126
+ }
127
+ },
128
+ })
129
+ })
130
+
131
+ const packet = await promiseWithTimeout(
132
+ firstResponse,
133
+ 'detached server-streaming response',
134
+ )
135
+ const body = packet.body
136
+ expect(body?.case).toBe('callData')
137
+ if (body?.case !== 'callData' || !body.value?.data) {
138
+ throw new Error('expected callData packet')
139
+ }
140
+ expect([...body.value.data]).toEqual([...response])
141
+ })
142
+
80
143
  it('removes abort listeners after a request completes', async () => {
81
144
  const controller = new AbortController()
82
145
  const removeEventListener = vi.spyOn(
@@ -114,3 +177,12 @@ describe('srpc server', () => {
114
177
  expect(port2.onmessage).toBe(null)
115
178
  })
116
179
  })
180
+
181
+ function promiseWithTimeout<T>(promise: Promise<T>, label: string): Promise<T> {
182
+ return Promise.race([
183
+ promise,
184
+ new Promise<T>((_, reject) => {
185
+ setTimeout(() => reject(new Error(`timed out waiting for ${label}`)), 100)
186
+ }),
187
+ ])
188
+ }