starpc 0.49.9 → 0.49.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/go.mod CHANGED
@@ -3,9 +3,9 @@ module github.com/aperturerobotics/starpc
3
3
  go 1.25.0
4
4
 
5
5
  require (
6
- github.com/aperturerobotics/common v0.33.0 // latest
6
+ github.com/aperturerobotics/common v0.33.1-0.20260516193515-675cfc5a0c12 // latest
7
7
  github.com/aperturerobotics/protobuf-go-lite v0.13.0 // latest
8
- github.com/aperturerobotics/util v1.34.3 // latest
8
+ github.com/aperturerobotics/util v1.34.5 // latest
9
9
  )
10
10
 
11
11
  require (
@@ -29,5 +29,5 @@ require (
29
29
  github.com/tetratelabs/wazero v1.11.0 // indirect
30
30
  github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
31
31
  golang.org/x/mod v0.35.0 // indirect
32
- golang.org/x/sys v0.43.0 // indirect
32
+ golang.org/x/sys v0.44.0 // indirect
33
33
  )
package/go.sum CHANGED
@@ -2,8 +2,8 @@ github.com/aperturerobotics/abseil-cpp v0.0.0-20260131110040-4bb56e2f9017 h1:3U7
2
2
  github.com/aperturerobotics/abseil-cpp v0.0.0-20260131110040-4bb56e2f9017/go.mod h1:lNSJTKECIUFAnfeSqy01kXYTYe1BHubW7198jNX3nEw=
3
3
  github.com/aperturerobotics/cli v1.1.0 h1:7a+YRC+EY3npAnTzhHV5gLCiw91KS0Ts3XwLILGOsT8=
4
4
  github.com/aperturerobotics/cli v1.1.0/go.mod h1:M7BFP9wow5ytTzMyJQOOO991fGfsUqdTI7gGEsHfTQ8=
5
- github.com/aperturerobotics/common v0.33.0 h1:IheETbaQPmvUpkm6Z+/1jbuAQOXZF5REnRRMXTaIeVk=
6
- github.com/aperturerobotics/common v0.33.0/go.mod h1:xabIJydWovkzjs5YZD8ru/BgFTAXekgHwV8DrTl3R2w=
5
+ github.com/aperturerobotics/common v0.33.1-0.20260516193515-675cfc5a0c12 h1:KQiFEDyu5//G8JZ3yTN92kF2P3F+LHjSKsQikgROncU=
6
+ github.com/aperturerobotics/common v0.33.1-0.20260516193515-675cfc5a0c12/go.mod h1:pW7BhBpKzVrTxN8XIz6jb3MJERZ15GDxaGQ7jcIl0Xg=
7
7
  github.com/aperturerobotics/go-protoc-gen-prost v0.0.0-20260329113538-218ccd8f20e0 h1:6/3RSSlPEQ6LeidslB1ZCJkxW+MnfYDkvdWMDklDXw4=
8
8
  github.com/aperturerobotics/go-protoc-gen-prost v0.0.0-20260329113538-218ccd8f20e0/go.mod h1:OBb/beWmr/pDIZAUfi86j/4tBh2v5ctTxKMqSnh9c/4=
9
9
  github.com/aperturerobotics/go-protoc-wasi v0.0.0-20260329113540-600516012db3 h1:lp+V8RYcBwTX1p81swkpZn5fhw1wn2xLorzETIxRyZQ=
@@ -16,8 +16,8 @@ github.com/aperturerobotics/protobuf v0.0.0-20260203024654-8201686529c4 h1:4Dy3B
16
16
  github.com/aperturerobotics/protobuf v0.0.0-20260203024654-8201686529c4/go.mod h1:tMgO7y6SJo/d9ZcvrpNqIQtdYT9de+QmYaHOZ4KnhOg=
17
17
  github.com/aperturerobotics/protobuf-go-lite v0.13.0 h1:jEvCJhHaJEikDY/va2AUnS0DOb/0n82aISLAqxSh4Sk=
18
18
  github.com/aperturerobotics/protobuf-go-lite v0.13.0/go.mod h1:lGH3s5ArCTXKI4wJdlNpaybUtwSjfAG0vdWjxOfMcF8=
19
- github.com/aperturerobotics/util v1.34.3 h1:9lFYJovGlAHYX66aWKVzfoVzYox13P054d/Dy8T/GRg=
20
- github.com/aperturerobotics/util v1.34.3/go.mod h1:xtE2hwpgKGaW0TLx01+9xLsG+rYvhygQ+JCuQeAbvME=
19
+ github.com/aperturerobotics/util v1.34.5 h1:007MaOJrrsiGm5o1c8Tt7p8nVwUAxkM6pmGflrBww/U=
20
+ github.com/aperturerobotics/util v1.34.5/go.mod h1:mDe7WnncVuV7yjeeVSsagyfrw4xfncu7d+f0+d70niY=
21
21
  github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
22
  github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
23
23
  github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -40,8 +40,8 @@ github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAz
40
40
  github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
41
41
  golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
42
42
  golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
43
- golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
44
- golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
43
+ golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
44
+ golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
45
45
  google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
46
46
  google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
47
47
  gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "starpc",
3
- "version": "0.49.9",
3
+ "version": "0.49.11",
4
4
  "description": "Streaming protobuf RPC service protocol over any two-way channel.",
5
5
  "license": "MIT",
6
6
  "author": {
@@ -38,22 +38,22 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
38
38
 
39
39
  var firstMsgEmpty bool
40
40
  var err error
41
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
42
- r.writer = writer
41
+ locked := r.bcast.Lock()
42
+ r.writer = writer
43
43
 
44
- if writeFirstMsg {
45
- firstMsgEmpty = len(firstMsg) == 0
46
- }
44
+ if writeFirstMsg {
45
+ firstMsgEmpty = len(firstMsg) == 0
46
+ }
47
47
 
48
- pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
49
- err = writer.WritePacket(pkt)
50
- if err != nil {
51
- r.ctxCancel()
52
- _ = writer.Close()
53
- }
48
+ pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
49
+ err = writer.WritePacket(pkt)
50
+ if err != nil {
51
+ r.ctxCancel()
52
+ _ = writer.Close()
53
+ }
54
54
 
55
- broadcast()
56
- })
55
+ locked.Broadcast()
56
+ locked.Unlock()
57
57
 
58
58
  return err
59
59
  }
@@ -69,14 +69,14 @@ func (r *ClientRPC) HandlePacketData(data []byte) error {
69
69
 
70
70
  // HandleStreamClose handles the stream closing optionally w/ an error.
71
71
  func (r *ClientRPC) HandleStreamClose(closeErr error) {
72
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
73
- if closeErr != nil && r.remoteErr == nil {
74
- r.remoteErr = closeErr
75
- }
76
- r.dataClosed = true
77
- r.ctxCancel()
78
- broadcast()
79
- })
72
+ locked := r.bcast.Lock()
73
+ if closeErr != nil && r.remoteErr == nil {
74
+ r.remoteErr = closeErr
75
+ }
76
+ r.dataClosed = true
77
+ r.ctxCancel()
78
+ locked.Broadcast()
79
+ locked.Unlock()
80
80
  }
81
81
 
82
82
  // HandlePacket handles an incoming parsed message packet.
@@ -108,11 +108,11 @@ func (r *ClientRPC) HandleCallStart(pkt *CallStart) error {
108
108
 
109
109
  // Close releases any resources held by the ClientRPC.
110
110
  func (r *ClientRPC) Close() {
111
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
112
- // call did not start yet if writer is nil.
113
- if r.writer != nil {
114
- _ = r.WriteCallCancel()
115
- r.closeLocked(broadcast)
116
- }
117
- })
111
+ locked := r.bcast.Lock()
112
+ // call did not start yet if writer is nil.
113
+ if r.writer != nil {
114
+ _ = r.WriteCallCancel()
115
+ r.closeLocked(&locked)
116
+ }
117
+ locked.Unlock()
118
118
  }
@@ -52,10 +52,12 @@ func (c *commonRPC) Wait(ctx context.Context) error {
52
52
  var err error
53
53
  var waitCh <-chan struct{}
54
54
  var rpcCtx context.Context
55
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
56
- rpcCtx, err = c.ctx, c.remoteErr
57
- waitCh = getWaitCh()
58
- })
55
+ locked := c.bcast.Lock()
56
+ rpcCtx, err = c.ctx, c.remoteErr
57
+ if err == nil && rpcCtx.Err() == nil {
58
+ waitCh = locked.WaitCh()
59
+ }
60
+ locked.Unlock()
59
61
 
60
62
  if err != nil {
61
63
  return err
@@ -78,43 +80,37 @@ func (c *commonRPC) Wait(ctx context.Context) error {
78
80
  //
79
81
  // returns io.EOF if the stream ended without a packet.
80
82
  func (c *commonRPC) ReadOne() ([]byte, error) {
81
- var hasMsg bool
82
- var msg []byte
83
- var err error
84
83
  var ctxDone bool
85
84
  for {
86
85
  var waitCh <-chan struct{}
87
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
88
- if ctxDone && !c.dataClosed {
89
- // context must have been canceled locally
90
- c.closeLocked(broadcast)
91
- err = context.Canceled
92
- return
93
- }
94
-
95
- if len(c.dataQueue) != 0 {
96
- msg = c.dataQueue[0]
97
- hasMsg = true
98
- c.dataQueue[0] = nil
99
- c.dataQueue = c.dataQueue[1:]
100
- } else if c.dataClosed || c.remoteErr != nil {
101
- err = c.remoteErr
102
- if err == nil {
103
- err = io.EOF
104
- }
105
- }
106
-
107
- waitCh = getWaitCh()
108
- })
86
+ locked := c.bcast.Lock()
87
+ if ctxDone && !c.dataClosed {
88
+ // context must have been canceled locally
89
+ c.closeLocked(&locked)
90
+ locked.Unlock()
91
+ return nil, context.Canceled
92
+ }
109
93
 
110
- if hasMsg {
94
+ if len(c.dataQueue) != 0 {
95
+ msg := c.dataQueue[0]
96
+ c.dataQueue[0] = nil
97
+ c.dataQueue = c.dataQueue[1:]
98
+ locked.Unlock()
111
99
  return msg, nil
112
100
  }
113
101
 
114
- if err != nil {
102
+ if c.dataClosed || c.remoteErr != nil {
103
+ err := c.remoteErr
104
+ if err == nil {
105
+ err = io.EOF
106
+ }
107
+ locked.Unlock()
115
108
  return nil, err
116
109
  }
117
110
 
111
+ waitCh = locked.WaitCh()
112
+ locked.Unlock()
113
+
118
114
  select {
119
115
  case <-c.ctx.Done():
120
116
  ctxDone = true
@@ -146,15 +142,23 @@ func (c *commonRPC) WriteCallData(data []byte, dataIsZero, complete bool, err er
146
142
 
147
143
  // HandleStreamClose handles the incoming stream closing w/ optional error.
148
144
  func (c *commonRPC) HandleStreamClose(closeErr error) {
149
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
150
- if closeErr != nil && c.remoteErr == nil {
151
- c.remoteErr = closeErr
152
- }
153
- c.dataClosed = true
154
- c.ctxCancel()
155
- _ = c.writer.Close()
156
- broadcast()
157
- })
145
+ var writer PacketWriter
146
+ locked := c.bcast.Lock()
147
+ if c.dataClosed {
148
+ locked.Unlock()
149
+ return
150
+ }
151
+ if closeErr != nil && c.remoteErr == nil {
152
+ c.remoteErr = closeErr
153
+ }
154
+ c.dataClosed = true
155
+ c.ctxCancel()
156
+ writer = c.writer
157
+ locked.Broadcast()
158
+ locked.Unlock()
159
+ if writer != nil {
160
+ _ = writer.Close()
161
+ }
158
162
  }
159
163
 
160
164
  // HandleCallCancel handles the call cancel packet.
@@ -166,34 +170,33 @@ func (c *commonRPC) HandleCallCancel() error {
166
170
  // HandleCallData handles the call data packet.
167
171
  func (c *commonRPC) HandleCallData(pkt *CallData) error {
168
172
  var err error
169
- c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
170
- if c.dataClosed {
171
- // If the packet is just indicating the call is complete, ignore it.
172
- if pkt.GetComplete() {
173
- return
174
- }
175
-
173
+ locked := c.bcast.Lock()
174
+ if c.dataClosed {
175
+ // If the packet is just indicating the call is complete, ignore it.
176
+ if !pkt.GetComplete() {
176
177
  // Otherwise, return ErrCompleted (unexpected packet).
177
178
  err = ErrCompleted
178
- return
179
179
  }
180
+ locked.Unlock()
181
+ return err
182
+ }
180
183
 
181
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
182
- c.dataQueue = append(c.dataQueue, data)
183
- }
184
+ if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
185
+ c.dataQueue = append(c.dataQueue, data)
186
+ }
184
187
 
185
- complete := pkt.GetComplete()
186
- if err := pkt.GetError(); len(err) != 0 {
187
- complete = true
188
- c.remoteErr = errors.New(err)
189
- }
188
+ complete := pkt.GetComplete()
189
+ if pktErr := pkt.GetError(); len(pktErr) != 0 {
190
+ complete = true
191
+ c.remoteErr = errors.New(pktErr)
192
+ }
190
193
 
191
- if complete {
192
- c.dataClosed = true
193
- }
194
+ if complete {
195
+ c.dataClosed = true
196
+ }
194
197
 
195
- broadcast()
196
- })
198
+ locked.Broadcast()
199
+ locked.Unlock()
197
200
 
198
201
  return err
199
202
  }
@@ -209,13 +212,16 @@ func (c *commonRPC) WriteCallCancel() error {
209
212
  }
210
213
 
211
214
  // closeLocked releases resources held by the RPC.
212
- func (c *commonRPC) closeLocked(broadcast func()) {
215
+ func (c *commonRPC) closeLocked(locked *broadcast.Locked) {
216
+ if c.dataClosed {
217
+ return
218
+ }
213
219
  c.dataClosed = true
214
220
  c.localCompleted.Store(true)
215
221
  if c.remoteErr == nil {
216
222
  c.remoteErr = context.Canceled
217
223
  }
218
224
  _ = c.writer.Close()
219
- broadcast()
225
+ locked.Broadcast()
220
226
  c.ctxCancel()
221
227
  }
@@ -0,0 +1,93 @@
1
+ package srpc
2
+
3
+ import (
4
+ "bytes"
5
+ "context"
6
+ "io"
7
+ "sync/atomic"
8
+ "testing"
9
+ )
10
+
11
+ type closeCountingPacketWriter struct {
12
+ closed atomic.Int32
13
+ }
14
+
15
+ type closeCallbackPacketWriter struct {
16
+ closeFn func()
17
+ }
18
+
19
+ func (w *closeCountingPacketWriter) WritePacket(*Packet) error {
20
+ return nil
21
+ }
22
+
23
+ func (w *closeCountingPacketWriter) Close() error {
24
+ w.closed.Add(1)
25
+ return nil
26
+ }
27
+
28
+ func (w *closeCallbackPacketWriter) WritePacket(*Packet) error {
29
+ return nil
30
+ }
31
+
32
+ func (w *closeCallbackPacketWriter) Close() error {
33
+ if w.closeFn != nil {
34
+ w.closeFn()
35
+ }
36
+ return nil
37
+ }
38
+
39
+ func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
40
+ writer := &closeCountingPacketWriter{}
41
+ rpc := NewServerRPC(context.Background(), InvokerFunc(nil), writer)
42
+
43
+ rpc.HandleStreamClose(io.EOF)
44
+ rpc.HandleStreamClose(context.Canceled)
45
+
46
+ if got := writer.closed.Load(); got != 1 {
47
+ t.Fatalf("expected writer closed once, got %d", got)
48
+ }
49
+ }
50
+
51
+ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
52
+ var rpc *ServerRPC
53
+ writerClosedOutsideLock := false
54
+ writer := &closeCallbackPacketWriter{
55
+ closeFn: func() {
56
+ locked, ok := rpc.bcast.TryLock()
57
+ if ok {
58
+ locked.Unlock()
59
+ }
60
+ writerClosedOutsideLock = ok
61
+ },
62
+ }
63
+ rpc = NewServerRPC(context.Background(), InvokerFunc(nil), writer)
64
+
65
+ rpc.HandleStreamClose(io.EOF)
66
+
67
+ if !writerClosedOutsideLock {
68
+ t.Fatal("expected writer close outside broadcast lock")
69
+ }
70
+ }
71
+
72
+ func TestCommonRPCReadOneQueuedDoesNotAllocate(t *testing.T) {
73
+ msg := []byte("message")
74
+ queue := make([][]byte, 1)
75
+ rpc := NewServerRPC(context.Background(), InvokerFunc(nil), &closeCountingPacketWriter{})
76
+
77
+ allocs := testing.AllocsPerRun(1000, func() {
78
+ queue[0] = msg
79
+ rpc.dataQueue = queue
80
+
81
+ got, err := rpc.ReadOne()
82
+ if err != nil {
83
+ t.Fatalf("read one: %v", err)
84
+ }
85
+ if !bytes.Equal(got, msg) {
86
+ t.Fatalf("expected %q, got %q", msg, got)
87
+ }
88
+ })
89
+
90
+ if allocs != 0 {
91
+ t.Fatalf("expected queued ReadOne to avoid allocations, got %f", allocs)
92
+ }
93
+ }
package/srpc/packet-rw.go CHANGED
@@ -10,8 +10,31 @@ import (
10
10
  "github.com/pkg/errors"
11
11
  )
12
12
 
13
- // maxMessageSize is the max message size in bytes
14
- const maxMessageSize = 10_000_000
13
+ const (
14
+ // maxMessageSize is the max message size in bytes.
15
+ maxMessageSize = 10_000_000
16
+ // readBufferSize is the packet read scratch buffer size.
17
+ readBufferSize = 2048
18
+ // pooledWriteBufferMaxSize is the largest outbound frame buffer to pool.
19
+ pooledWriteBufferMaxSize = 64 * 1024
20
+ )
21
+
22
+ var (
23
+ readBufferPool = sync.Pool{
24
+ New: func() any {
25
+ return new([readBufferSize]byte)
26
+ },
27
+ }
28
+ writeBufferPool = sync.Pool{
29
+ New: func() any {
30
+ return new(writeBuffer)
31
+ },
32
+ }
33
+ )
34
+
35
+ type writeBuffer struct {
36
+ data []byte
37
+ }
15
38
 
16
39
  // PacketReadWriter reads and writes packets from a io.ReadWriter.
17
40
  // Uses a LittleEndian uint32 length prefix.
@@ -46,7 +69,9 @@ func (r *PacketReadWriter) WritePacket(p *Packet) error {
46
69
  return errors.Errorf("message size %v greater than maximum %v", msgSize, maxMessageSize)
47
70
  }
48
71
 
49
- data := make([]byte, 4+msgSize)
72
+ writeBuf := getWriteBuffer(4 + msgSize)
73
+ defer putWriteBuffer(writeBuf)
74
+ data := writeBuf.data
50
75
  binary.LittleEndian.PutUint32(data, uint32(msgSize)) //nolint:gosec
51
76
 
52
77
  _, err := p.MarshalToSizedBufferVT(data[4:])
@@ -56,10 +81,13 @@ func (r *PacketReadWriter) WritePacket(p *Packet) error {
56
81
 
57
82
  var written, n int
58
83
  for written < len(data) {
59
- n, err = r.rw.Write(data)
84
+ n, err = r.rw.Write(data[written:])
60
85
  if err != nil {
61
86
  return err
62
87
  }
88
+ if n == 0 {
89
+ return io.ErrShortWrite
90
+ }
63
91
  written += n
64
92
  }
65
93
 
@@ -81,7 +109,9 @@ func (r *PacketReadWriter) ReadPump(cb PacketDataHandler, closed CloseHandler) {
81
109
  // Does not handle closing the stream, use ReadPump instead.
82
110
  func (r *PacketReadWriter) ReadToHandler(cb PacketDataHandler) error {
83
111
  var currLen uint32
84
- buf := make([]byte, 2048)
112
+ bufPtr := readBufferPool.Get().(*[readBufferSize]byte)
113
+ defer readBufferPool.Put(bufPtr)
114
+ buf := bufPtr[:]
85
115
  isOpen := true
86
116
 
87
117
  for isOpen {
@@ -150,5 +180,25 @@ func (r *PacketReadWriter) readLengthPrefix(b []byte) uint32 {
150
180
  return binary.LittleEndian.Uint32(b)
151
181
  }
152
182
 
183
+ func getWriteBuffer(size int) *writeBuffer {
184
+ if size > pooledWriteBufferMaxSize {
185
+ return &writeBuffer{data: make([]byte, size)}
186
+ }
187
+ buf := writeBufferPool.Get().(*writeBuffer)
188
+ if cap(buf.data) < size {
189
+ buf.data = make([]byte, size)
190
+ }
191
+ buf.data = buf.data[:size]
192
+ return buf
193
+ }
194
+
195
+ func putWriteBuffer(buf *writeBuffer) {
196
+ if cap(buf.data) <= pooledWriteBufferMaxSize {
197
+ clear(buf.data)
198
+ buf.data = buf.data[:0]
199
+ writeBufferPool.Put(buf)
200
+ }
201
+ }
202
+
153
203
  // _ is a type assertion
154
204
  var _ PacketWriter = (*PacketReadWriter)(nil)
@@ -0,0 +1,68 @@
1
+ package srpc
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/binary"
6
+ "io"
7
+ "testing"
8
+ )
9
+
10
+ type chunkedReadWriteCloser struct {
11
+ bytes.Buffer
12
+ maxWrite int
13
+ }
14
+
15
+ func (c *chunkedReadWriteCloser) Read([]byte) (int, error) {
16
+ return 0, io.EOF
17
+ }
18
+
19
+ func (c *chunkedReadWriteCloser) Write(p []byte) (int, error) {
20
+ if len(p) > c.maxWrite {
21
+ p = p[:c.maxWrite]
22
+ }
23
+ return c.Buffer.Write(p)
24
+ }
25
+
26
+ func (c *chunkedReadWriteCloser) Close() error {
27
+ return nil
28
+ }
29
+
30
+ func TestPacketReadWriterWritePacketHandlesShortWrites(t *testing.T) {
31
+ pkt := NewCallDataPacket([]byte("packet payload"), false, true, nil)
32
+ size := pkt.SizeVT()
33
+ want := make([]byte, 4+size)
34
+ binary.LittleEndian.PutUint32(want, uint32(size)) //nolint:gosec
35
+ if _, err := pkt.MarshalToSizedBufferVT(want[4:]); err != nil {
36
+ t.Fatal(err)
37
+ }
38
+
39
+ rwc := &chunkedReadWriteCloser{maxWrite: 3}
40
+ if err := NewPacketReadWriter(rwc).WritePacket(pkt); err != nil {
41
+ t.Fatal(err)
42
+ }
43
+ if got := rwc.Bytes(); !bytes.Equal(got, want) {
44
+ t.Fatalf("written packet mismatch:\ngot %x\nwant %x", got, want)
45
+ }
46
+ }
47
+
48
+ func TestPacketUnmarshalCopiesByteFields(t *testing.T) {
49
+ want := []byte("stable data")
50
+ srcPkt := NewCallDataPacket(want, false, true, nil)
51
+ data, err := srcPkt.MarshalVT()
52
+ if err != nil {
53
+ t.Fatal(err)
54
+ }
55
+
56
+ var pkt Packet
57
+ if err := pkt.UnmarshalVT(data); err != nil {
58
+ t.Fatal(err)
59
+ }
60
+ for i := range data {
61
+ data[i] = 0xff
62
+ }
63
+
64
+ got := pkt.GetCallData().GetData()
65
+ if !bytes.Equal(got, want) {
66
+ t.Fatalf("unmarshal retained source bytes: got %q want %q", got, want)
67
+ }
68
+ }
@@ -59,29 +59,31 @@ func (r *ServerRPC) HandlePacket(msg *Packet) error {
59
59
  func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
60
60
  var err error
61
61
 
62
- r.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
63
- // process start: method and service
64
- if r.method != "" || r.service != "" {
65
- err = errors.New("call start must be sent only once")
66
- return
67
- }
68
- if r.dataClosed {
69
- err = ErrCompleted
70
- return
71
- }
62
+ locked := r.bcast.Lock()
63
+ // process start: method and service
64
+ if r.method != "" || r.service != "" {
65
+ err = errors.New("call start must be sent only once")
66
+ locked.Unlock()
67
+ return err
68
+ }
69
+ if r.dataClosed {
70
+ err = ErrCompleted
71
+ locked.Unlock()
72
+ return err
73
+ }
72
74
 
73
- service, method := pkt.GetRpcService(), pkt.GetRpcMethod()
74
- r.service, r.method = service, method
75
+ service, method := pkt.GetRpcService(), pkt.GetRpcMethod()
76
+ r.service, r.method = service, method
75
77
 
76
- // process first data packet, if included
77
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
78
- r.dataQueue = append(r.dataQueue, data)
79
- }
78
+ // process first data packet, if included
79
+ if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
80
+ r.dataQueue = append(r.dataQueue, data)
81
+ }
80
82
 
81
- // invoke the rpc
82
- broadcast()
83
- go r.invokeRPC(service, method)
84
- })
83
+ // invoke the rpc
84
+ locked.Broadcast()
85
+ go r.invokeRPC(service, method)
86
+ locked.Unlock()
85
87
 
86
88
  return err
87
89
  }