starpc 0.14.0 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,9 @@
1
1
  import type { TsProtoRpc } from './ts-proto-rpc.js';
2
2
  import type { OpenStreamFunc } from './stream.js';
3
3
  export declare class Client implements TsProtoRpc {
4
- private openStreamFn;
5
- private _openStreamFn?;
4
+ private openStreamCtr;
6
5
  constructor(openStreamFn?: OpenStreamFunc);
7
- setOpenStreamFn(openStreamFn?: OpenStreamFunc): Promise<OpenStreamFunc>;
8
- private initOpenStreamFn;
6
+ setOpenStreamFn(openStreamFn?: OpenStreamFunc): void;
9
7
  request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
10
8
  clientStreamingRequest(service: string, method: string, data: AsyncIterable<Uint8Array>): Promise<Uint8Array>;
11
9
  serverStreamingRequest(service: string, method: string, data: Uint8Array): AsyncIterable<Uint8Array>;
@@ -4,43 +4,15 @@ import { ClientRPC } from './client-rpc.js';
4
4
  import { writeToPushable } from './pushable.js';
5
5
  import { decodePacketSource, encodePacketSource, parseLengthPrefixTransform, prependLengthPrefixTransform, } from './packet.js';
6
6
  import { combineUint8ArrayListTransform } from './array-list.js';
7
+ import { OpenStreamCtr } from './open-stream-ctr.js';
7
8
  // Client implements the ts-proto Rpc interface with the drpcproto protocol.
8
9
  export class Client {
9
10
  constructor(openStreamFn) {
10
- this.openStreamFn = this.setOpenStreamFn(openStreamFn);
11
+ this.openStreamCtr = new OpenStreamCtr(openStreamFn || undefined);
11
12
  }
12
13
  // setOpenStreamFn updates the openStreamFn for the Client.
13
14
  setOpenStreamFn(openStreamFn) {
14
- if (this._openStreamFn) {
15
- if (openStreamFn) {
16
- this._openStreamFn(openStreamFn);
17
- this._openStreamFn = undefined;
18
- }
19
- }
20
- else {
21
- if (openStreamFn) {
22
- this.openStreamFn = Promise.resolve(openStreamFn);
23
- }
24
- else {
25
- this.initOpenStreamFn();
26
- }
27
- }
28
- return this.openStreamFn;
29
- }
30
- // initOpenStreamFn creates the empty Promise for openStreamFn.
31
- initOpenStreamFn() {
32
- const openPromise = new Promise((resolve, reject) => {
33
- this._openStreamFn = (conn, err) => {
34
- if (err) {
35
- reject(err);
36
- }
37
- else if (conn) {
38
- resolve(conn);
39
- }
40
- };
41
- });
42
- this.openStreamFn = openPromise;
43
- return this.openStreamFn;
15
+ this.openStreamCtr.set(openStreamFn || undefined);
44
16
  }
45
17
  // request starts a non-streaming request.
46
18
  async request(service, method, data) {
@@ -98,7 +70,7 @@ export class Client {
98
70
  // throws any error starting the rpc call
99
71
  // if data == null and data.length == 0, sends a separate data packet.
100
72
  async startRpc(rpcService, rpcMethod, data) {
101
- const openStreamFn = await this.openStreamFn;
73
+ const openStreamFn = await this.openStreamCtr.wait();
102
74
  const conn = await openStreamFn();
103
75
  const call = new ClientRPC(rpcService, rpcMethod);
104
76
  pipe(conn, parseLengthPrefixTransform(), combineUint8ArrayListTransform(), decodePacketSource, call, encodePacketSource, prependLengthPrefixTransform(), conn);
@@ -7,4 +7,6 @@ export { Packet, CallStart, CallData } from './rpcproto.pb.js';
7
7
  export { Mux, StaticMux, createMux } from './mux.js';
8
8
  export { BroadcastChannelDuplex, newBroadcastChannelDuplex, BroadcastChannelConn, } from './broadcast-channel.js';
9
9
  export { MessagePortIterable, newMessagePortIterable, MessagePortConn, } from './message-port.js';
10
+ export { ValueCtr } from './value-ctr.js';
11
+ export { OpenStreamCtr } from './open-stream-ctr.js';
10
12
  export { writeToPushable } from './pushable';
@@ -6,4 +6,6 @@ export { Packet, CallStart, CallData } from './rpcproto.pb.js';
6
6
  export { StaticMux, createMux } from './mux.js';
7
7
  export { BroadcastChannelDuplex, newBroadcastChannelDuplex, BroadcastChannelConn, } from './broadcast-channel.js';
8
8
  export { MessagePortIterable, newMessagePortIterable, MessagePortConn, } from './message-port.js';
9
+ export { ValueCtr } from './value-ctr.js';
10
+ export { OpenStreamCtr } from './open-stream-ctr.js';
9
11
  export { writeToPushable } from './pushable';
@@ -0,0 +1,6 @@
1
+ import { OpenStreamFunc } from './stream.js';
2
+ import { ValueCtr } from './value-ctr.js';
3
+ export declare class OpenStreamCtr extends ValueCtr<OpenStreamFunc> {
4
+ constructor(openStreamFn?: OpenStreamFunc);
5
+ get openStreamFunc(): OpenStreamFunc;
6
+ }
@@ -0,0 +1,17 @@
1
+ import { ValueCtr } from './value-ctr.js';
2
+ // OpenStreamCtr contains an OpenStream func which can be awaited.
3
+ export class OpenStreamCtr extends ValueCtr {
4
+ constructor(openStreamFn) {
5
+ super(openStreamFn);
6
+ }
7
+ // openStreamFunc returns an OpenStreamFunc which waits for the underlying OpenStreamFunc.
8
+ get openStreamFunc() {
9
+ return async () => {
10
+ let openFn = this.value;
11
+ if (!openFn) {
12
+ openFn = await this.wait();
13
+ }
14
+ return openFn();
15
+ };
16
+ }
17
+ }
@@ -0,0 +1,9 @@
1
+ export declare class ValueCtr<T> {
2
+ private _value;
3
+ private _waiters;
4
+ constructor(initialValue?: T);
5
+ get value(): T | undefined;
6
+ wait(): Promise<T>;
7
+ waitWithCb(cb: (val: T) => void): void;
8
+ set(val: T | undefined): void;
9
+ }
@@ -0,0 +1,44 @@
1
+ // ValueCtr contains a value that can be set asynchronously.
2
+ export class ValueCtr {
3
+ constructor(initialValue) {
4
+ this._value = initialValue || undefined;
5
+ this._waiters = [];
6
+ }
7
+ // value returns the current value.
8
+ get value() {
9
+ return this._value;
10
+ }
11
+ // wait waits for the value to not be undefined.
12
+ async wait() {
13
+ const currVal = this._value;
14
+ if (currVal !== undefined) {
15
+ return currVal;
16
+ }
17
+ return new Promise((resolve) => {
18
+ this.waitWithCb((val) => {
19
+ resolve(val);
20
+ });
21
+ });
22
+ }
23
+ // waitWithCb adds a callback to be called when the value is not undefined.
24
+ waitWithCb(cb) {
25
+ if (cb) {
26
+ this._waiters.push(cb);
27
+ }
28
+ }
29
+ // set sets the value and calls the callbacks.
30
+ set(val) {
31
+ this._value = val;
32
+ if (val === undefined) {
33
+ return;
34
+ }
35
+ const waiters = this._waiters;
36
+ if (waiters.length === 0) {
37
+ return;
38
+ }
39
+ this._waiters = [];
40
+ for (const waiter of waiters) {
41
+ waiter(val);
42
+ }
43
+ }
44
+ }
package/go.mod CHANGED
@@ -9,6 +9,7 @@ require (
9
9
  )
10
10
 
11
11
  require (
12
+ github.com/aperturerobotics/util v0.0.0-20221125031036-d86bb72019f6
12
13
  github.com/libp2p/go-libp2p v0.23.3-0.20221109121032-c334288f8fe4
13
14
  github.com/libp2p/go-yamux/v4 v4.0.1-0.20220919134236-1c09f2ab3ec1
14
15
  github.com/sirupsen/logrus v1.9.0
package/go.sum CHANGED
@@ -1,3 +1,5 @@
1
+ github.com/aperturerobotics/util v0.0.0-20221125031036-d86bb72019f6 h1:qABa0SjNgWhwcqM5RhB+PYaHxeYoiBfHcAkbkGnh1Nw=
2
+ github.com/aperturerobotics/util v0.0.0-20221125031036-d86bb72019f6/go.mod h1:XEAYwKz5dN1eqc7LFyoprXBD7AR/X3UC6o6NqSxnAG0=
1
3
  github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2
4
  github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3
5
  github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -28,8 +30,8 @@ github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgj
28
30
  github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
29
31
  github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
30
32
  github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
31
- github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
32
33
  github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
34
+ github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
33
35
  github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
34
36
  github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
35
37
  github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "starpc",
3
- "version": "0.14.0",
3
+ "version": "0.15.0",
4
4
  "description": "Streaming protobuf RPC service protocol over any two-way channel.",
5
5
  "license": "MIT",
6
6
  "author": {
@@ -2,34 +2,13 @@ package srpc
2
2
 
3
3
  import (
4
4
  "context"
5
- "io"
6
5
 
7
6
  "github.com/pkg/errors"
8
7
  )
9
8
 
10
9
  // ClientRPC represents the client side of an on-going RPC call message stream.
11
- // Not concurrency safe: use a mutex if calling concurrently.
12
10
  type ClientRPC struct {
13
- // ctx is the context, canceled when the rpc ends.
14
- ctx context.Context
15
- // ctxCancel is called when the rpc ends.
16
- ctxCancel context.CancelFunc
17
- // writer is the writer to write messages to
18
- writer Writer
19
- // service is the rpc service
20
- service string
21
- // method is the rpc method
22
- method string
23
- // dataCh contains queued data packets.
24
- // closed when the client closes the channel.
25
- dataCh chan []byte
26
- // dataChClosed is a flag set after dataCh is closed.
27
- // controlled by HandlePacket.
28
- dataChClosed bool
29
- // serverErr is an error set by the client.
30
- // before dataCh is closed, managed by HandlePacket.
31
- // immutable after dataCh is closed.
32
- serverErr error
11
+ commonRPC
33
12
  }
34
13
 
35
14
  // NewClientRPC constructs a new ClientRPC session and writes CallStart.
@@ -37,12 +16,10 @@ type ClientRPC struct {
37
16
  // service and method must be specified.
38
17
  // must call Start after creating the RPC object.
39
18
  func NewClientRPC(ctx context.Context, service, method string) *ClientRPC {
40
- rpc := &ClientRPC{
41
- service: service,
42
- method: method,
43
- dataCh: make(chan []byte, 5),
44
- }
45
- rpc.ctx, rpc.ctxCancel = context.WithCancel(ctx)
19
+ rpc := &ClientRPC{}
20
+ initCommonRPC(ctx, &rpc.commonRPC)
21
+ rpc.service = service
22
+ rpc.method = method
46
23
  return rpc
47
24
  }
48
25
 
@@ -51,73 +28,28 @@ func NewClientRPC(ctx context.Context, service, method string) *ClientRPC {
51
28
  func (r *ClientRPC) Start(writer Writer, writeFirstMsg bool, firstMsg []byte) error {
52
29
  select {
53
30
  case <-r.ctx.Done():
54
- r.Close()
31
+ r.ctxCancel()
55
32
  return context.Canceled
56
33
  default:
57
34
  }
35
+ r.mtx.Lock()
36
+ defer r.mtx.Unlock()
37
+ defer r.bcast.Broadcast()
58
38
  r.writer = writer
59
39
  var firstMsgEmpty bool
60
40
  if writeFirstMsg {
61
41
  firstMsgEmpty = len(firstMsg) == 0
62
- } else {
63
- firstMsg = nil
64
42
  }
65
43
  pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
66
44
  if err := writer.WritePacket(pkt); err != nil {
67
- r.Close()
45
+ r.ctxCancel()
46
+ _ = writer.Close()
68
47
  return err
69
48
  }
70
49
  return nil
71
50
  }
72
51
 
73
- // SendCancel sends the message notifying the peer we want to cancel the request.
74
- func (r *ClientRPC) SendCancel(writer Writer) error {
75
- pkt := NewCallCancelPacket()
76
- return writer.WritePacket(pkt)
77
- }
78
-
79
- // ReadAll reads all returned Data packets and returns any error.
80
- // intended for use with unary rpcs.
81
- func (r *ClientRPC) ReadAll() ([][]byte, error) {
82
- msgs := make([][]byte, 0, 1)
83
- for {
84
- select {
85
- case <-r.ctx.Done():
86
- return msgs, context.Canceled
87
- case data, ok := <-r.dataCh:
88
- if !ok {
89
- return msgs, r.serverErr
90
- }
91
- msgs = append(msgs, data)
92
- }
93
- }
94
- }
95
-
96
- // ReadOne reads a single message and returns.
97
- //
98
- // returns io.EOF if the stream ended.
99
- func (r *ClientRPC) ReadOne() ([]byte, error) {
100
- select {
101
- case <-r.ctx.Done():
102
- return nil, context.Canceled
103
- case data, ok := <-r.dataCh:
104
- if !ok {
105
- if err := r.serverErr; err != nil {
106
- return nil, err
107
- }
108
- return nil, io.EOF
109
- }
110
- return data, nil
111
- }
112
- }
113
-
114
- // Context is canceled when the ClientRPC is no longer valid.
115
- func (r *ClientRPC) Context() context.Context {
116
- return r.ctx
117
- }
118
-
119
52
  // HandlePacketData handles an incoming unparsed message packet.
120
- // Not concurrency safe: use a mutex if calling concurrently.
121
53
  func (r *ClientRPC) HandlePacketData(data []byte) error {
122
54
  pkt := &Packet{}
123
55
  if err := pkt.UnmarshalVT(data); err != nil {
@@ -128,16 +60,17 @@ func (r *ClientRPC) HandlePacketData(data []byte) error {
128
60
 
129
61
  // HandleStreamClose handles the incoming stream closing w/ optional error.
130
62
  func (r *ClientRPC) HandleStreamClose(closeErr error) {
131
- if closeErr != nil {
132
- if r.serverErr == nil {
133
- r.serverErr = closeErr
134
- }
135
- r.Close()
63
+ r.mtx.Lock()
64
+ defer r.mtx.Unlock()
65
+ defer r.bcast.Broadcast()
66
+ if closeErr != nil && r.remoteErr == nil {
67
+ r.remoteErr = closeErr
136
68
  }
69
+ r.dataClosed = true
70
+ r.ctxCancel()
137
71
  }
138
72
 
139
73
  // HandlePacket handles an incoming parsed message packet.
140
- // Not concurrency safe: use a mutex if calling concurrently.
141
74
  func (r *ClientRPC) HandlePacket(msg *Packet) error {
142
75
  if err := msg.Validate(); err != nil {
143
76
  return err
@@ -148,6 +81,11 @@ func (r *ClientRPC) HandlePacket(msg *Packet) error {
148
81
  return r.HandleCallStart(b.CallStart)
149
82
  case *Packet_CallData:
150
83
  return r.HandleCallData(b.CallData)
84
+ case *Packet_CallCancel:
85
+ if b.CallCancel {
86
+ return r.HandleCallCancel()
87
+ }
88
+ return nil
151
89
  default:
152
90
  return nil
153
91
  }
@@ -159,39 +97,14 @@ func (r *ClientRPC) HandleCallStart(pkt *CallStart) error {
159
97
  return errors.Wrap(ErrUnrecognizedPacket, "call start packet unexpected")
160
98
  }
161
99
 
162
- // HandleCallData handles the call data packet.
163
- func (r *ClientRPC) HandleCallData(pkt *CallData) error {
164
- if r.dataChClosed {
165
- return ErrCompleted
166
- }
167
-
168
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
169
- select {
170
- case <-r.ctx.Done():
171
- return context.Canceled
172
- case r.dataCh <- data:
173
- }
174
- }
175
-
176
- complete := pkt.GetComplete()
177
- if err := pkt.GetError(); len(err) != 0 {
178
- complete = true
179
- r.serverErr = errors.New(err)
180
- }
181
-
182
- if complete {
183
- r.dataChClosed = true
184
- close(r.dataCh)
185
- }
186
-
187
- return nil
188
- }
189
-
190
100
  // Close releases any resources held by the ClientRPC.
191
- // not concurrency safe with HandlePacket.
192
101
  func (r *ClientRPC) Close() {
193
- r.ctxCancel()
102
+ r.mtx.Lock()
194
103
  if r.writer != nil {
104
+ _ = r.writeCancelLocked()
195
105
  _ = r.writer.Close()
196
106
  }
107
+ r.bcast.Broadcast()
108
+ r.mtx.Unlock()
109
+ r.ctxCancel()
197
110
  }
package/srpc/client.go CHANGED
@@ -38,34 +38,31 @@ func NewClient(openStream OpenStreamFunc) Client {
38
38
  }
39
39
 
40
40
  // ExecCall executes a request/reply RPC with the remote.
41
- func (c *client) ExecCall(rctx context.Context, service, method string, in, out Message) error {
42
- ctx, ctxCancel := context.WithCancel(rctx)
43
- defer ctxCancel()
44
-
41
+ func (c *client) ExecCall(ctx context.Context, service, method string, in, out Message) error {
45
42
  firstMsg, err := in.MarshalVT()
46
43
  if err != nil {
47
44
  return err
48
45
  }
46
+
49
47
  clientRPC := NewClientRPC(ctx, service, method)
48
+ defer clientRPC.Close()
49
+
50
50
  writer, err := c.openStream(ctx, clientRPC.HandlePacket, clientRPC.HandleStreamClose)
51
51
  if err != nil {
52
52
  return err
53
53
  }
54
- defer writer.Close()
55
54
  if err := clientRPC.Start(writer, true, firstMsg); err != nil {
56
55
  return err
57
56
  }
57
+
58
58
  msg, err := clientRPC.ReadOne()
59
59
  if err != nil {
60
- // send cancel message, if possible.
61
- _ = clientRPC.SendCancel(writer)
62
60
  // this includes any server returned error.
63
61
  return err
64
62
  }
65
63
  if err := out.UnmarshalVT(msg); err != nil {
66
64
  return errors.Wrap(ErrInvalidMessage, err.Error())
67
65
  }
68
- // done
69
66
  return nil
70
67
  }
71
68
 
@@ -90,7 +87,7 @@ func (c *client) NewStream(ctx context.Context, service, method string, firstMsg
90
87
  return nil, err
91
88
  }
92
89
 
93
- return NewMsgStream(ctx, clientRPC.writer, clientRPC.dataCh, clientRPC.ctxCancel), nil
90
+ return NewMsgStream(ctx, clientRPC, clientRPC.ctxCancel), nil
94
91
  }
95
92
 
96
93
  // _ is a type assertion
package/srpc/client.ts CHANGED
@@ -12,50 +12,20 @@ import {
12
12
  prependLengthPrefixTransform,
13
13
  } from './packet.js'
14
14
  import { combineUint8ArrayListTransform } from './array-list.js'
15
+ import { OpenStreamCtr } from './open-stream-ctr.js'
15
16
 
16
17
  // Client implements the ts-proto Rpc interface with the drpcproto protocol.
17
18
  export class Client implements TsProtoRpc {
18
- // openStreamFn is a promise which contains the OpenStreamFunc.
19
- private openStreamFn: Promise<OpenStreamFunc>
20
- // _openStreamFn resolves openStreamFn.
21
- private _openStreamFn?: (conn?: OpenStreamFunc, err?: Error) => void
19
+ // openStreamCtr contains the OpenStreamFunc.
20
+ private openStreamCtr: OpenStreamCtr
22
21
 
23
22
  constructor(openStreamFn?: OpenStreamFunc) {
24
- this.openStreamFn = this.setOpenStreamFn(openStreamFn)
23
+ this.openStreamCtr = new OpenStreamCtr(openStreamFn || undefined)
25
24
  }
26
25
 
27
26
  // setOpenStreamFn updates the openStreamFn for the Client.
28
- public setOpenStreamFn(
29
- openStreamFn?: OpenStreamFunc
30
- ): Promise<OpenStreamFunc> {
31
- if (this._openStreamFn) {
32
- if (openStreamFn) {
33
- this._openStreamFn(openStreamFn)
34
- this._openStreamFn = undefined
35
- }
36
- } else {
37
- if (openStreamFn) {
38
- this.openStreamFn = Promise.resolve(openStreamFn)
39
- } else {
40
- this.initOpenStreamFn()
41
- }
42
- }
43
- return this.openStreamFn
44
- }
45
-
46
- // initOpenStreamFn creates the empty Promise for openStreamFn.
47
- private initOpenStreamFn(): Promise<OpenStreamFunc> {
48
- const openPromise = new Promise<OpenStreamFunc>((resolve, reject) => {
49
- this._openStreamFn = (conn?: OpenStreamFunc, err?: Error) => {
50
- if (err) {
51
- reject(err)
52
- } else if (conn) {
53
- resolve(conn)
54
- }
55
- }
56
- })
57
- this.openStreamFn = openPromise
58
- return this.openStreamFn
27
+ public setOpenStreamFn(openStreamFn?: OpenStreamFunc) {
28
+ this.openStreamCtr.set(openStreamFn || undefined)
59
29
  }
60
30
 
61
31
  // request starts a non-streaming request.
@@ -137,7 +107,7 @@ export class Client implements TsProtoRpc {
137
107
  rpcMethod: string,
138
108
  data: Uint8Array | null
139
109
  ): Promise<ClientRPC> {
140
- const openStreamFn = await this.openStreamFn
110
+ const openStreamFn = await this.openStreamCtr.wait()
141
111
  const conn = await openStreamFn()
142
112
  const call = new ClientRPC(rpcService, rpcMethod)
143
113
  pipe(
@@ -0,0 +1,197 @@
1
+ package srpc
2
+
3
+ import (
4
+ "context"
5
+ "io"
6
+ "sync"
7
+
8
+ "github.com/aperturerobotics/util/broadcast"
9
+ "github.com/pkg/errors"
10
+ )
11
+
12
+ // commonRPC contains common logic between server/client rpc.
13
+ type commonRPC struct {
14
+ // ctx is the context, canceled when the rpc ends.
15
+ ctx context.Context
16
+ // ctxCancel is called when the rpc ends.
17
+ ctxCancel context.CancelFunc
18
+ // service is the rpc service
19
+ service string
20
+ // method is the rpc method
21
+ method string
22
+ // mtx guards below fields
23
+ mtx sync.Mutex
24
+ // bcast broadcasts when below fields change
25
+ bcast broadcast.Broadcast
26
+ // writer is the writer to write messages to
27
+ writer Writer
28
+ // dataQueue contains incoming data packets.
29
+ // note: packets may be len() == 0
30
+ dataQueue [][]byte
31
+ // dataClosed is a flag set after dataQueue is closed.
32
+ // controlled by HandlePacket.
33
+ dataClosed bool
34
+ // remoteErr is an error set by the remote.
35
+ remoteErr error
36
+ }
37
+
38
+ // initCommonRPC initializes the commonRPC.
39
+ func initCommonRPC(ctx context.Context, rpc *commonRPC) {
40
+ rpc.ctx, rpc.ctxCancel = context.WithCancel(ctx)
41
+ }
42
+
43
+ // Context is canceled when the rpc has finished.
44
+ func (c *commonRPC) Context() context.Context {
45
+ return c.ctx
46
+ }
47
+
48
+ // Wait waits for the RPC to finish.
49
+ func (c *commonRPC) Wait(ctx context.Context) error {
50
+ for {
51
+ c.mtx.Lock()
52
+ if c.dataClosed {
53
+ break
54
+ }
55
+ waiter := c.bcast.GetWaitCh()
56
+ c.mtx.Unlock()
57
+ select {
58
+ case <-ctx.Done():
59
+ return context.Canceled
60
+ case <-waiter:
61
+ }
62
+ }
63
+ err := c.remoteErr
64
+ c.mtx.Unlock()
65
+ return err
66
+ }
67
+
68
+ // ReadOne reads a single message and returns.
69
+ //
70
+ // returns io.EOF if the stream ended without a packet.
71
+ func (c *commonRPC) ReadOne() ([]byte, error) {
72
+ var msg []byte
73
+ var err error
74
+ var ctxDone bool
75
+ for {
76
+ c.mtx.Lock()
77
+ waiter := c.bcast.GetWaitCh()
78
+ if ctxDone && !c.dataClosed {
79
+ // context must have been canceled locally
80
+ c.closeLocked()
81
+ err = context.Canceled
82
+ c.mtx.Unlock()
83
+ break
84
+ }
85
+ if len(c.dataQueue) != 0 {
86
+ msg = c.dataQueue[0]
87
+ c.dataQueue[0] = nil
88
+ c.dataQueue = c.dataQueue[1:]
89
+ c.mtx.Unlock()
90
+ break
91
+ }
92
+ if c.dataClosed || c.remoteErr != nil {
93
+ err = c.remoteErr
94
+ if err == nil {
95
+ err = io.EOF
96
+ }
97
+ c.mtx.Unlock()
98
+ break
99
+ }
100
+ c.mtx.Unlock()
101
+ select {
102
+ case <-c.ctx.Done():
103
+ ctxDone = true
104
+ case <-waiter:
105
+ }
106
+ }
107
+ return msg, err
108
+ }
109
+
110
+ // WriteCallData writes a call data packet.
111
+ func (c *commonRPC) WriteCallData(data []byte, complete bool, err error) error {
112
+ c.mtx.Lock()
113
+ defer c.mtx.Unlock()
114
+ if c.writer == nil {
115
+ return ErrCompleted
116
+ }
117
+ outPkt := NewCallDataPacket(data, len(data) == 0, false, nil)
118
+ return c.writer.WritePacket(outPkt)
119
+ }
120
+
121
+ // HandleStreamClose handles the incoming stream closing w/ optional error.
122
+ func (c *commonRPC) HandleStreamClose(closeErr error) {
123
+ c.mtx.Lock()
124
+ defer c.mtx.Unlock()
125
+ if closeErr != nil && c.remoteErr == nil {
126
+ c.remoteErr = closeErr
127
+ }
128
+ c.dataClosed = true
129
+ c.ctxCancel()
130
+ if c.writer != nil {
131
+ _ = c.writer.Close()
132
+ }
133
+ c.bcast.Broadcast()
134
+ }
135
+
136
+ // HandleCallCancel handles the call cancel packet.
137
+ func (c *commonRPC) HandleCallCancel() error {
138
+ c.mtx.Lock()
139
+ defer c.mtx.Unlock()
140
+ if c.remoteErr != nil {
141
+ c.remoteErr = context.Canceled
142
+ }
143
+ c.dataClosed = true
144
+ if c.writer != nil {
145
+ _ = c.writer.Close()
146
+ }
147
+ c.bcast.Broadcast()
148
+ return nil
149
+ }
150
+
151
+ // HandleCallData handles the call data packet.
152
+ func (c *commonRPC) HandleCallData(pkt *CallData) error {
153
+ c.mtx.Lock()
154
+ defer c.mtx.Unlock()
155
+
156
+ if c.dataClosed {
157
+ return ErrCompleted
158
+ }
159
+
160
+ if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
161
+ c.dataQueue = append(c.dataQueue, data)
162
+ }
163
+
164
+ complete := pkt.GetComplete()
165
+ if err := pkt.GetError(); len(err) != 0 {
166
+ complete = true
167
+ c.remoteErr = errors.New(err)
168
+ }
169
+
170
+ if complete {
171
+ c.dataClosed = true
172
+ }
173
+
174
+ c.bcast.Broadcast()
175
+ return nil
176
+ }
177
+
178
+ // writeCancelLocked writes a call cancel packet.
179
+ func (c *commonRPC) writeCancelLocked() error {
180
+ if c.writer != nil {
181
+ return c.writer.WritePacket(NewCallCancelPacket())
182
+ }
183
+ return nil
184
+ }
185
+
186
+ // closeLocked releases resources held by the RPC.
187
+ func (c *commonRPC) closeLocked() {
188
+ c.dataClosed = true
189
+ if c.remoteErr == nil {
190
+ c.remoteErr = context.Canceled
191
+ }
192
+ if c.writer != nil {
193
+ _ = c.writeCancelLocked()
194
+ _ = c.writer.Close()
195
+ }
196
+ c.ctxCancel()
197
+ }
package/srpc/index.ts CHANGED
@@ -15,4 +15,6 @@ export {
15
15
  newMessagePortIterable,
16
16
  MessagePortConn,
17
17
  } from './message-port.js'
18
+ export { ValueCtr } from './value-ctr.js'
19
+ export { OpenStreamCtr } from './open-stream-ctr.js'
18
20
  export { writeToPushable } from './pushable'
@@ -2,36 +2,38 @@ package srpc
2
2
 
3
3
  import (
4
4
  "context"
5
- "io"
6
- "sync/atomic"
7
5
  )
8
6
 
7
+ // MsgStreamRw is the read-write interface for MsgStream.
8
+ type MsgStreamRw interface {
9
+ // ReadOne reads a single message and returns.
10
+ //
11
+ // returns io.EOF if the stream ended.
12
+ ReadOne() ([]byte, error)
13
+ // WriteCallData writes a call data packet.
14
+ WriteCallData(data []byte, complete bool, err error) error
15
+ }
16
+
9
17
  // MsgStream implements the stream interface passed to implementations.
10
18
  type MsgStream struct {
11
19
  // ctx is the stream context
12
20
  ctx context.Context
13
- // writer is the stream writer
14
- writer Writer
15
- // dataCh is the incoming data channel.
16
- dataCh chan []byte
21
+ // rw is the msg stream read-writer
22
+ rw MsgStreamRw
17
23
  // closeCb is the close callback
18
24
  closeCb func()
19
- // sendClosed indicates sending is closed.
20
- sendClosed atomic.Bool
21
25
  }
22
26
 
23
27
  // NewMsgStream constructs a new Stream with a ClientRPC.
24
28
  // dataCh should be closed when no more messages will arrive.
25
29
  func NewMsgStream(
26
30
  ctx context.Context,
27
- writer Writer,
28
- dataCh chan []byte,
31
+ rw MsgStreamRw,
29
32
  closeCb func(),
30
33
  ) *MsgStream {
31
34
  return &MsgStream{
32
35
  ctx: ctx,
33
- writer: writer,
34
- dataCh: dataCh,
36
+ rw: rw,
35
37
  closeCb: closeCb,
36
38
  }
37
39
  }
@@ -43,9 +45,6 @@ func (r *MsgStream) Context() context.Context {
43
45
 
44
46
  // MsgSend sends the message to the remote.
45
47
  func (r *MsgStream) MsgSend(msg Message) error {
46
- if r.sendClosed.Load() {
47
- return context.Canceled
48
- }
49
48
  select {
50
49
  case <-r.ctx.Done():
51
50
  return context.Canceled
@@ -56,40 +55,27 @@ func (r *MsgStream) MsgSend(msg Message) error {
56
55
  if err != nil {
57
56
  return err
58
57
  }
59
- outPkt := NewCallDataPacket(msgData, len(msgData) == 0, false, nil)
60
- return r.writer.WritePacket(outPkt)
58
+ return r.rw.WriteCallData(msgData, false, nil)
61
59
  }
62
60
 
63
61
  // MsgRecv receives an incoming message from the remote.
64
62
  // Parses the message into the object at msg.
65
63
  func (r *MsgStream) MsgRecv(msg Message) error {
66
- select {
67
- case <-r.Context().Done():
68
- return context.Canceled
69
- case data, ok := <-r.dataCh:
70
- if !ok {
71
- return io.EOF
72
- }
73
- return msg.UnmarshalVT(data)
64
+ data, err := r.rw.ReadOne()
65
+ if err != nil {
66
+ return err
74
67
  }
68
+ return msg.UnmarshalVT(data)
75
69
  }
76
70
 
77
71
  // CloseSend signals to the remote that we will no longer send any messages.
78
72
  func (r *MsgStream) CloseSend() error {
79
- if !r.sendClosed.Swap(true) {
80
- return r.closeSend()
81
- }
82
- return nil
73
+ return r.rw.WriteCallData(nil, true, nil)
83
74
  }
84
75
 
85
76
  // Close closes the stream.
86
77
  func (r *MsgStream) Close() error {
87
- if !r.sendClosed.Swap(true) {
88
- if err := r.closeSend(); err != nil {
89
- return err
90
- }
91
- }
92
-
78
+ _ = r.CloseSend()
93
79
  if r.closeCb != nil {
94
80
  r.closeCb()
95
81
  }
@@ -97,11 +83,5 @@ func (r *MsgStream) Close() error {
97
83
  return nil
98
84
  }
99
85
 
100
- // closeSend writes the CloseSend packet.
101
- func (r *MsgStream) closeSend() error {
102
- outPkt := NewCallDataPacket(nil, false, true, nil)
103
- return r.writer.WritePacket(outPkt)
104
- }
105
-
106
86
  // _ is a type assertion
107
87
  var _ Stream = ((*MsgStream)(nil))
@@ -15,6 +15,7 @@ func NewYamuxConfig() *yamux.Config {
15
15
  // Configuration options from go-libp2p-yamux:
16
16
  config := *ymuxer.DefaultTransport.Config()
17
17
  config.AcceptBacklog = 512
18
+ config.EnableKeepAlive = false
18
19
  return &config
19
20
  }
20
21
 
@@ -0,0 +1,20 @@
1
+ import { OpenStreamFunc, Stream } from './stream.js'
2
+ import { ValueCtr } from './value-ctr.js'
3
+
4
+ // OpenStreamCtr contains an OpenStream func which can be awaited.
5
+ export class OpenStreamCtr extends ValueCtr<OpenStreamFunc> {
6
+ constructor(openStreamFn?: OpenStreamFunc) {
7
+ super(openStreamFn)
8
+ }
9
+
10
+ // openStreamFunc returns an OpenStreamFunc which waits for the underlying OpenStreamFunc.
11
+ get openStreamFunc(): OpenStreamFunc {
12
+ return async (): Promise<Stream> => {
13
+ let openFn = this.value
14
+ if (!openFn) {
15
+ openFn = await this.wait()
16
+ }
17
+ return openFn()
18
+ }
19
+ }
20
+ }
package/srpc/packet-rw.go CHANGED
@@ -94,7 +94,7 @@ func (r *PacketReaderWriter) ReadToHandler(cb PacketHandler) error {
94
94
 
95
95
  // parse the length prefix if not done already
96
96
  if currLen == 0 {
97
- currLen = r.readLengthPrefix(r.buf.Bytes())
97
+ currLen = r.readLengthPrefix(r.buf.Bytes()[:4])
98
98
  if currLen == 0 {
99
99
  return errors.New("unexpected zero len prefix")
100
100
  }
@@ -55,11 +55,6 @@ func (s *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
55
55
  }
56
56
  return
57
57
  }
58
- go func() {
59
- err := s.srpc.HandleStream(ctx, strm)
60
- _ = err
61
- // TODO: handle / log error?
62
- // err != nil && err != io.EOF && err != context.Canceled
63
- }()
58
+ go s.srpc.HandleStream(ctx, strm)
64
59
  }
65
60
  }
@@ -11,9 +11,7 @@ import (
11
11
  func NewServerPipe(server *Server) OpenStreamFunc {
12
12
  return func(ctx context.Context, msgHandler PacketHandler, closeHandler CloseHandler) (Writer, error) {
13
13
  srvPipe, clientPipe := net.Pipe()
14
- go func() {
15
- _ = server.HandleStream(ctx, srvPipe)
16
- }()
14
+ go server.HandleStream(ctx, srvPipe)
17
15
  clientPrw := NewPacketReadWriter(clientPipe)
18
16
  go clientPrw.ReadPump(msgHandler, closeHandler)
19
17
  return clientPrw, nil
@@ -2,88 +2,27 @@ package srpc
2
2
 
3
3
  import (
4
4
  "context"
5
- "io"
6
5
 
7
6
  "github.com/pkg/errors"
8
7
  )
9
8
 
10
9
  // ServerRPC represents the server side of an on-going RPC call message stream.
11
- // Not concurrency safe: use a mutex if calling concurrently.
12
10
  type ServerRPC struct {
13
- // ctx is the context, canceled when the rpc ends.
14
- ctx context.Context
15
- // ctxCancel is called when the rpc ends.
16
- ctxCancel context.CancelFunc
17
- // writer is the writer to write messages to
18
- writer Writer
11
+ commonRPC
19
12
  // invoker is the rpc call invoker
20
13
  invoker Invoker
21
- // service is the rpc service
22
- service string
23
- // method is the rpc method
24
- method string
25
- // dataCh contains queued data packets.
26
- // closed when the client closes the channel.
27
- dataCh chan []byte
28
- // dataChClosed is a flag set after dataCh is closed.
29
- // controlled by HandlePacket.
30
- dataChClosed bool
31
- // clientErr is an error set by the client.
32
- // before dataCh is closed, managed by HandlePacket.
33
- // immutable after dataCh is closed or ctxCancel
34
- clientErr error
35
14
  }
36
15
 
37
16
  // NewServerRPC constructs a new ServerRPC session.
38
17
  // note: call SetWriter before handling any incoming messages.
39
- func NewServerRPC(ctx context.Context, invoker Invoker) *ServerRPC {
40
- rpc := &ServerRPC{
41
- dataCh: make(chan []byte, 5),
42
- invoker: invoker,
43
- }
44
- rpc.ctx, rpc.ctxCancel = context.WithCancel(ctx)
18
+ func NewServerRPC(ctx context.Context, invoker Invoker, writer Writer) *ServerRPC {
19
+ rpc := &ServerRPC{invoker: invoker}
20
+ initCommonRPC(ctx, &rpc.commonRPC)
21
+ rpc.writer = writer
45
22
  return rpc
46
23
  }
47
24
 
48
- // SetWriter sets the writer field.
49
- func (r *ServerRPC) SetWriter(w Writer) {
50
- r.writer = w
51
- }
52
-
53
- // Context is canceled when the ServerRPC is no longer valid.
54
- func (r *ServerRPC) Context() context.Context {
55
- return r.ctx
56
- }
57
-
58
- // Wait waits for the RPC to finish.
59
- func (r *ServerRPC) Wait(ctx context.Context) error {
60
- select {
61
- case <-ctx.Done():
62
- return context.Canceled
63
- case <-r.ctx.Done():
64
- }
65
- return r.clientErr
66
- }
67
-
68
- // HandleStreamClose handles the incoming stream closing w/ optional error.
69
- func (r *ServerRPC) HandleStreamClose(closeErr error) {
70
- if r.dataChClosed {
71
- return
72
- }
73
- if closeErr != nil {
74
- if r.clientErr == nil {
75
- r.clientErr = closeErr
76
- }
77
- if closeErr != io.EOF && closeErr != context.Canceled {
78
- r.Close()
79
- }
80
- }
81
- r.dataChClosed = true
82
- close(r.dataCh)
83
- }
84
-
85
25
  // HandlePacket handles an incoming parsed message packet.
86
- // Not concurrency safe: use a mutex if calling concurrently.
87
26
  func (r *ServerRPC) HandlePacket(msg *Packet) error {
88
27
  if msg == nil {
89
28
  return nil
@@ -99,7 +38,7 @@ func (r *ServerRPC) HandlePacket(msg *Packet) error {
99
38
  return r.HandleCallData(b.CallData)
100
39
  case *Packet_CallCancel:
101
40
  if b.CallCancel {
102
- r.Close()
41
+ return r.HandleCallCancel()
103
42
  }
104
43
  return nil
105
44
  default:
@@ -109,86 +48,40 @@ func (r *ServerRPC) HandlePacket(msg *Packet) error {
109
48
 
110
49
  // HandleCallStart handles the call start packet.
111
50
  func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
51
+ r.mtx.Lock()
52
+ defer r.mtx.Unlock()
112
53
  // process start: method and service
113
54
  if r.method != "" || r.service != "" {
114
55
  return errors.New("call start must be sent only once")
115
56
  }
116
- if r.dataChClosed {
57
+ if r.dataClosed {
117
58
  return ErrCompleted
118
59
  }
119
- r.method, r.service = pkt.GetRpcMethod(), pkt.GetRpcService()
60
+ service, method := pkt.GetRpcService(), pkt.GetRpcMethod()
61
+ r.service, r.method = service, method
120
62
 
121
63
  // process first data packet, if included
122
64
  if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
123
- if data == nil {
124
- data = []byte{}
125
- }
126
- select {
127
- case r.dataCh <- data:
128
- default:
129
- // the channel should be empty w/ a buffer capacity of 5 here.
130
- return errors.New("data channel was full, expected empty")
131
- }
65
+ r.dataQueue = append(r.dataQueue, data)
132
66
  }
133
67
 
134
68
  // invoke the rpc
135
- go r.invokeRPC()
136
-
137
- return nil
138
- }
139
-
140
- // HandleCallData handles the call data packet.
141
- func (r *ServerRPC) HandleCallData(pkt *CallData) error {
142
- if r.dataChClosed {
143
- return ErrCompleted
144
- }
145
-
146
- if data := pkt.GetData(); len(data) != 0 || pkt.GetDataIsZero() {
147
- select {
148
- case <-r.ctx.Done():
149
- return context.Canceled
150
- case r.dataCh <- data:
151
- }
152
- }
153
-
154
- complete := pkt.GetComplete()
155
- if err := pkt.GetError(); len(err) != 0 {
156
- complete = true
157
- r.clientErr = errors.New(err)
158
- }
159
-
160
- if complete {
161
- r.dataChClosed = true
162
- close(r.dataCh)
163
- }
164
-
69
+ r.bcast.Broadcast()
70
+ go r.invokeRPC(service, method)
165
71
  return nil
166
72
  }
167
73
 
168
- // invoke invokes the RPC after CallStart is received.
169
- func (r *ServerRPC) invokeRPC() {
74
+ // invokeRPC invokes the RPC after CallStart is received.
75
+ func (r *ServerRPC) invokeRPC(serviceID, methodID string) {
170
76
  // ctx := r.ctx
171
- serviceID, methodID := r.service, r.method
172
- strm := NewMsgStream(r.ctx, r.writer, r.dataCh, r.ctxCancel)
77
+ strm := NewMsgStream(r.ctx, r, r.ctxCancel)
173
78
  ok, err := r.invoker.InvokeMethod(serviceID, methodID, strm)
174
79
  if err == nil && !ok {
175
80
  err = ErrUnimplemented
176
81
  }
82
+ // TODO: close dataCh here?
177
83
  outPkt := NewCallDataPacket(nil, false, true, err)
178
84
  _ = r.writer.WritePacket(outPkt)
179
85
  _ = r.writer.Close()
180
86
  r.ctxCancel()
181
87
  }
182
-
183
- // Close releases any resources held by the ServerRPC.
184
- // not concurrency safe with HandlePacket.
185
- func (r *ServerRPC) Close() {
186
- if r.clientErr == nil {
187
- r.clientErr = context.Canceled
188
- }
189
- if r.writer != nil && r.service == "" {
190
- // invokeRPC has not been called, otherwise it would call Close()
191
- _ = r.writer.Close()
192
- }
193
- r.ctxCancel()
194
- }
package/srpc/server.go CHANGED
@@ -25,15 +25,13 @@ func (s *Server) GetInvoker() Invoker {
25
25
  return s.invoker
26
26
  }
27
27
 
28
- // HandleStream handles an incoming ReadWriteCloser stream.
29
- func (s *Server) HandleStream(ctx context.Context, rwc io.ReadWriteCloser) error {
28
+ // HandleStream handles an incoming stream and runs the read loop.
29
+ func (s *Server) HandleStream(ctx context.Context, rwc io.ReadWriteCloser) {
30
30
  subCtx, subCtxCancel := context.WithCancel(ctx)
31
31
  defer subCtxCancel()
32
- serverRPC := NewServerRPC(subCtx, s.invoker)
33
32
  prw := NewPacketReadWriter(rwc)
34
- serverRPC.SetWriter(prw)
35
- go prw.ReadPump(serverRPC.HandlePacket, serverRPC.HandleStreamClose)
36
- return serverRPC.Wait(ctx)
33
+ serverRPC := NewServerRPC(subCtx, s.invoker, prw)
34
+ prw.ReadPump(serverRPC.HandlePacket, serverRPC.HandleStreamClose)
37
35
  }
38
36
 
39
37
  // AcceptMuxedConn runs a loop which calls Accept on a muxer to handle streams.
@@ -55,8 +53,6 @@ func (s *Server) AcceptMuxedConn(ctx context.Context, mc network.MuxedConn) erro
55
53
  if err != nil {
56
54
  return err
57
55
  }
58
- go func() {
59
- _ = s.HandleStream(ctx, muxedStream)
60
- }()
56
+ go s.HandleStream(ctx, muxedStream)
61
57
  }
62
58
  }
@@ -0,0 +1,54 @@
// ValueCtr contains a value that can be set asynchronously.
//
// A value of undefined means "unset". Waiters registered with wait() or
// waitWithCb() are invoked, then discarded, the next time set() stores a
// value other than undefined.
export class ValueCtr<T> {
  // _value contains the current value, or undefined if unset.
  private _value: T | undefined
  // _waiters contains the list of pending waiters.
  // called when the value is set to any value other than undefined.
  private _waiters: ((fn: T) => void)[]

  // constructor optionally seeds the container with an initial value.
  // Falsy values such as 0, '' and false are valid initial values; only
  // null/undefined leave the container unset. (Using `||` here would
  // silently drop falsy values and make wait() block forever.)
  constructor(initialValue?: T) {
    this._value = initialValue ?? undefined
    this._waiters = []
  }

  // value returns the current value, or undefined if unset.
  get value(): T | undefined {
    return this._value
  }

  // wait waits for the value to not be undefined.
  // Resolves immediately if a value is already set; otherwise resolves on
  // the next set() call with a value other than undefined.
  public async wait(): Promise<T> {
    const currVal = this._value
    if (currVal !== undefined) {
      return currVal
    }
    return new Promise<T>((resolve) => {
      this.waitWithCb((val: T) => {
        resolve(val)
      })
    })
  }

  // waitWithCb registers a callback to be called the next time set() stores
  // a value other than undefined. The callback is invoked at most once.
  // Note: the callback is NOT invoked immediately if a value is already set;
  // use wait() for "current or next value" semantics.
  public waitWithCb(cb: (val: T) => void) {
    if (cb) {
      this._waiters.push(cb)
    }
  }

  // set sets the value and, if it is not undefined, calls the pending waiters.
  // The waiter list is swapped out before the callbacks run, so a callback
  // that registers a new waiter is queued for the following set() instead.
  public set(val: T | undefined) {
    this._value = val
    if (val === undefined) {
      return
    }
    const waiters = this._waiters
    if (waiters.length === 0) {
      return
    }
    this._waiters = []
    for (const waiter of waiters) {
      waiter(val)
    }
  }
}
package/srpc/websocket.go CHANGED
@@ -2,65 +2,20 @@ package srpc
2
2
 
3
3
  import (
4
4
  "context"
5
- "io"
6
5
 
7
6
  "github.com/libp2p/go-libp2p/core/network"
8
7
  "github.com/libp2p/go-yamux/v4"
9
8
  "nhooyr.io/websocket"
10
9
  )
11
10
 
12
- // WebSocketConn implements the p2p multiplexer over a WebSocket.
13
- type WebSocketConn struct {
14
- // conn is the websocket conn
15
- conn *websocket.Conn
16
- // mconn is the muxed conn
17
- mconn network.MuxedConn
18
- }
19
-
20
- // NewWebSocketConn constructs a new WebSocket connection.
21
- //
11
+ // NewWebSocketConn wraps a websocket into a MuxedConn.
22
12
  // if yamuxConf is unset, uses the defaults.
23
- func NewWebSocketConn(ctx context.Context, conn *websocket.Conn, isServer bool, yamuxConf *yamux.Config) (*WebSocketConn, error) {
13
+ func NewWebSocketConn(
14
+ ctx context.Context,
15
+ conn *websocket.Conn,
16
+ isServer bool,
17
+ yamuxConf *yamux.Config,
18
+ ) (network.MuxedConn, error) {
24
19
  nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
25
- muxedConn, err := NewMuxedConn(nc, !isServer, yamuxConf)
26
- if err != nil {
27
- return nil, err
28
- }
29
- return &WebSocketConn{conn: conn, mconn: muxedConn}, nil
30
- }
31
-
32
- // GetWebSocket returns the web socket conn.
33
- func (w *WebSocketConn) GetWebSocket() *websocket.Conn {
34
- return w.conn
35
- }
36
-
37
- // GetOpenStreamFunc returns the OpenStream func.
38
- func (w *WebSocketConn) GetOpenStreamFunc() OpenStreamFunc {
39
- return w.OpenStream
40
- }
41
-
42
- // AcceptStream accepts an incoming stream.
43
- func (w *WebSocketConn) AcceptStream() (io.ReadWriteCloser, error) {
44
- strm, err := w.mconn.AcceptStream()
45
- if err != nil {
46
- return nil, err
47
- }
48
- return strm, nil
49
- }
50
-
51
- // OpenStream tries to open a stream with the remote.
52
- func (w *WebSocketConn) OpenStream(ctx context.Context, msgHandler PacketHandler, closeHandler CloseHandler) (Writer, error) {
53
- muxedStream, err := w.mconn.OpenStream(ctx)
54
- if err != nil {
55
- return nil, err
56
- }
57
-
58
- rw := NewPacketReadWriter(muxedStream)
59
- go rw.ReadPump(msgHandler, closeHandler)
60
- return rw, nil
61
- }
62
-
63
- // Close closes the writer.
64
- func (w *WebSocketConn) Close() error {
65
- return w.conn.Close(websocket.StatusGoingAway, "conn closed")
20
+ return NewMuxedConn(nc, !isServer, yamuxConf)
66
21
  }