starpc 0.49.15 → 0.49.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/srpc/common-rpc.d.ts +8 -1
- package/dist/srpc/common-rpc.js +31 -4
- package/dist/srpc/common-rpc.test.d.ts +1 -0
- package/dist/srpc/common-rpc.test.js +178 -0
- package/package.json +9 -3
- package/srpc/client-rpc.go +3 -3
- package/srpc/client.go +1 -1
- package/srpc/common-rpc.go +56 -12
- package/srpc/common-rpc.test.ts +256 -0
- package/srpc/common-rpc.ts +60 -14
- package/srpc/common-rpc_test.go +242 -0
- package/srpc/rwc-conn.go +3 -1
- package/srpc/server-rpc.go +3 -3
- package/srpc/server.go +3 -1
- package/srpc/stream-pipe.go +4 -2
|
@@ -10,12 +10,14 @@ export declare class CommonRPC {
|
|
|
10
10
|
protected service?: string;
|
|
11
11
|
protected method?: string;
|
|
12
12
|
private closed?;
|
|
13
|
+
private readonly writeDrainAbort;
|
|
13
14
|
constructor();
|
|
14
15
|
get isClosed(): boolean | Error;
|
|
15
16
|
writeCallData(data?: Uint8Array, complete?: boolean, error?: string): Promise<void>;
|
|
17
|
+
private writeCallDataPacket;
|
|
16
18
|
writeCallCancel(): Promise<void>;
|
|
17
19
|
writeCallDataFromSource(dataSource: AsyncIterable<Uint8Array>): Promise<void>;
|
|
18
|
-
protected writePacket(packet: Packet): Promise<void>;
|
|
20
|
+
protected writePacket(packet: Packet, options?: WritePacketOptions): Promise<void>;
|
|
19
21
|
handleMessage(message: Uint8Array): Promise<void>;
|
|
20
22
|
handlePacket(packet: Packet): Promise<void>;
|
|
21
23
|
handleCallStart(packet: Partial<CallStart>): Promise<void>;
|
|
@@ -25,3 +27,8 @@ export declare class CommonRPC {
|
|
|
25
27
|
close(err?: Error): Promise<void>;
|
|
26
28
|
private _createSink;
|
|
27
29
|
}
|
|
30
|
+
interface WritePacketOptions {
|
|
31
|
+
allowClosed?: boolean;
|
|
32
|
+
waitForDrain?: boolean;
|
|
33
|
+
}
|
|
34
|
+
export {};
|
package/dist/srpc/common-rpc.js
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import { pushable } from 'it-pushable';
|
|
2
2
|
import { Packet } from './rpcproto.pb.js';
|
|
3
3
|
import { ERR_RPC_ABORT, RemoteRPCError } from './errors.js';
|
|
4
|
+
const maxBufferedOutgoingPackets = 1;
|
|
4
5
|
// CommonRPC is common logic between server and client RPCs.
|
|
5
6
|
export class CommonRPC {
|
|
6
7
|
// sink is the data sink for incoming messages.
|
|
@@ -23,6 +24,8 @@ export class CommonRPC {
|
|
|
23
24
|
method;
|
|
24
25
|
// closed indicates this rpc has been closed already.
|
|
25
26
|
closed;
|
|
27
|
+
// writeDrainAbort wakes writers waiting for outbound stream drain on close.
|
|
28
|
+
writeDrainAbort = new AbortController();
|
|
26
29
|
constructor() {
|
|
27
30
|
this.sink = this._createSink();
|
|
28
31
|
this.source = this._source;
|
|
@@ -34,6 +37,10 @@ export class CommonRPC {
|
|
|
34
37
|
}
|
|
35
38
|
// writeCallData writes the call data packet.
|
|
36
39
|
async writeCallData(data, complete, error) {
|
|
40
|
+
await this.writeCallDataPacket(data, complete, error);
|
|
41
|
+
}
|
|
42
|
+
// writeCallDataPacket writes a call-data packet with optional drain control.
|
|
43
|
+
async writeCallDataPacket(data, complete, error, writeOptions) {
|
|
37
44
|
const callData = {
|
|
38
45
|
data: data || new Uint8Array(0),
|
|
39
46
|
dataIsZero: !!data && data.length === 0,
|
|
@@ -45,7 +52,7 @@ export class CommonRPC {
|
|
|
45
52
|
case: 'callData',
|
|
46
53
|
value: callData,
|
|
47
54
|
},
|
|
48
|
-
});
|
|
55
|
+
}, writeOptions);
|
|
49
56
|
}
|
|
50
57
|
// writeCallCancel writes the call cancel packet.
|
|
51
58
|
async writeCallCancel() {
|
|
@@ -54,7 +61,7 @@ export class CommonRPC {
|
|
|
54
61
|
case: 'callCancel',
|
|
55
62
|
value: true,
|
|
56
63
|
},
|
|
57
|
-
});
|
|
64
|
+
}, { waitForDrain: false });
|
|
58
65
|
}
|
|
59
66
|
// writeCallDataFromSource writes all call data from the iterable.
|
|
60
67
|
async writeCallDataFromSource(dataSource) {
|
|
@@ -69,8 +76,24 @@ export class CommonRPC {
|
|
|
69
76
|
}
|
|
70
77
|
}
|
|
71
78
|
// writePacket writes a packet to the stream.
|
|
72
|
-
async writePacket(packet) {
|
|
79
|
+
async writePacket(packet, options) {
|
|
80
|
+
if (this.closed && !options?.allowClosed) {
|
|
81
|
+
throw new Error(ERR_RPC_ABORT);
|
|
82
|
+
}
|
|
73
83
|
this._source.push(packet);
|
|
84
|
+
if (options?.waitForDrain === false ||
|
|
85
|
+
this._source.readableLength <= maxBufferedOutgoingPackets) {
|
|
86
|
+
return;
|
|
87
|
+
}
|
|
88
|
+
try {
|
|
89
|
+
await this._source.onEmpty({ signal: this.writeDrainAbort.signal });
|
|
90
|
+
}
|
|
91
|
+
catch (err) {
|
|
92
|
+
if (this.closed) {
|
|
93
|
+
throw new Error(ERR_RPC_ABORT, { cause: err });
|
|
94
|
+
}
|
|
95
|
+
throw err;
|
|
96
|
+
}
|
|
74
97
|
}
|
|
75
98
|
// handleMessage handles an incoming encoded Packet.
|
|
76
99
|
//
|
|
@@ -149,8 +172,12 @@ export class CommonRPC {
|
|
|
149
172
|
this.closed = err ?? true;
|
|
150
173
|
// note: this does nothing if _source is already ended.
|
|
151
174
|
if (err && err.message) {
|
|
152
|
-
await this.
|
|
175
|
+
await this.writeCallDataPacket(undefined, true, err.message, {
|
|
176
|
+
allowClosed: true,
|
|
177
|
+
waitForDrain: false,
|
|
178
|
+
});
|
|
153
179
|
}
|
|
180
|
+
this.writeDrainAbort.abort();
|
|
154
181
|
this._source.end();
|
|
155
182
|
this._rpcDataSource.end(err);
|
|
156
183
|
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import { describe, expect, it } from 'vitest';
|
|
2
|
+
import { Client } from './client.js';
|
|
3
|
+
import { CommonRPC } from './common-rpc.js';
|
|
4
|
+
import { Packet } from './rpcproto.pb.js';
|
|
5
|
+
describe('CommonRPC', () => {
|
|
6
|
+
it('backpressures call-data sources until outbound packets drain', async () => {
|
|
7
|
+
const rpc = new CommonRPC();
|
|
8
|
+
const yielded = { count: 0 };
|
|
9
|
+
const secondYield = deferred();
|
|
10
|
+
const thirdYield = deferred();
|
|
11
|
+
const writeDone = rpc.writeCallDataFromSource(chunkSource(yielded, new Map([
|
|
12
|
+
[1, secondYield],
|
|
13
|
+
[2, thirdYield],
|
|
14
|
+
])));
|
|
15
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk');
|
|
16
|
+
expect(yielded.count).toBe(2);
|
|
17
|
+
await expectPending(thirdYield.promise, 'third call-data chunk');
|
|
18
|
+
expect(yielded.count).toBe(2);
|
|
19
|
+
const received = new Array();
|
|
20
|
+
const readComplete = (async () => {
|
|
21
|
+
for await (const packet of rpc.source) {
|
|
22
|
+
const body = packet.body;
|
|
23
|
+
if (body?.case !== 'callData') {
|
|
24
|
+
continue;
|
|
25
|
+
}
|
|
26
|
+
if (body.value.complete) {
|
|
27
|
+
return;
|
|
28
|
+
}
|
|
29
|
+
if (!body.value.data) {
|
|
30
|
+
throw new Error('call data packet missing data');
|
|
31
|
+
}
|
|
32
|
+
received.push(body.value.data[0]);
|
|
33
|
+
}
|
|
34
|
+
throw new Error('outbound call data ended before completion');
|
|
35
|
+
})();
|
|
36
|
+
await promiseWithTimeout(Promise.all([writeDone, readComplete]), 'backpressured call data drain');
|
|
37
|
+
await rpc.close();
|
|
38
|
+
expect(yielded.count).toBe(32);
|
|
39
|
+
expect(received).toHaveLength(32);
|
|
40
|
+
expect(received[0]).toBe(0);
|
|
41
|
+
expect(received[31]).toBe(31);
|
|
42
|
+
});
|
|
43
|
+
it('wakes call-data writers when the rpc closes while waiting for drain', async () => {
|
|
44
|
+
const rpc = new CommonRPC();
|
|
45
|
+
const yielded = { count: 0 };
|
|
46
|
+
const secondYield = deferred();
|
|
47
|
+
const thirdYield = deferred();
|
|
48
|
+
const writeDone = rpc.writeCallDataFromSource(chunkSource(yielded, new Map([
|
|
49
|
+
[1, secondYield],
|
|
50
|
+
[2, thirdYield],
|
|
51
|
+
])));
|
|
52
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk');
|
|
53
|
+
expect(yielded.count).toBe(2);
|
|
54
|
+
await expectPending(thirdYield.promise, 'third call-data chunk');
|
|
55
|
+
expect(yielded.count).toBe(2);
|
|
56
|
+
await rpc.close();
|
|
57
|
+
await expect(promiseWithTimeout(writeDone, 'call data writer close')).resolves.toBeUndefined();
|
|
58
|
+
expect(yielded.count).toBeLessThan(32);
|
|
59
|
+
});
|
|
60
|
+
it('wakes call-data writers when an error closes the rpc while waiting for drain', async () => {
|
|
61
|
+
const rpc = new CommonRPC();
|
|
62
|
+
const yielded = { count: 0 };
|
|
63
|
+
const secondYield = deferred();
|
|
64
|
+
const thirdYield = deferred();
|
|
65
|
+
const writeDone = rpc.writeCallDataFromSource(chunkSource(yielded, new Map([
|
|
66
|
+
[1, secondYield],
|
|
67
|
+
[2, thirdYield],
|
|
68
|
+
])));
|
|
69
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk');
|
|
70
|
+
expect(yielded.count).toBe(2);
|
|
71
|
+
await expectPending(thirdYield.promise, 'third call-data chunk');
|
|
72
|
+
expect(yielded.count).toBe(2);
|
|
73
|
+
await rpc.close(new Error('boom'));
|
|
74
|
+
await expect(promiseWithTimeout(writeDone, 'call data writer error close')).resolves.toBeUndefined();
|
|
75
|
+
expect(yielded.count).toBeLessThan(32);
|
|
76
|
+
});
|
|
77
|
+
it('does not backpressure call-cancel packets behind queued call data', async () => {
|
|
78
|
+
const rpc = new CommonRPC();
|
|
79
|
+
const yielded = { count: 0 };
|
|
80
|
+
const secondYield = deferred();
|
|
81
|
+
const thirdYield = deferred();
|
|
82
|
+
const writeDone = rpc.writeCallDataFromSource(chunkSource(yielded, new Map([
|
|
83
|
+
[1, secondYield],
|
|
84
|
+
[2, thirdYield],
|
|
85
|
+
])));
|
|
86
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk');
|
|
87
|
+
expect(yielded.count).toBe(2);
|
|
88
|
+
await expectPending(thirdYield.promise, 'third call-data chunk');
|
|
89
|
+
expect(yielded.count).toBe(2);
|
|
90
|
+
const cancelDone = rpc.writeCallCancel();
|
|
91
|
+
await rpc.close(new Error('abort'));
|
|
92
|
+
await expect(promiseWithTimeout(cancelDone, 'call cancel write')).resolves.toBeUndefined();
|
|
93
|
+
await expect(promiseWithTimeout(writeDone, 'blocked call data close')).resolves.toBeUndefined();
|
|
94
|
+
});
|
|
95
|
+
it('backpressures client-stream sources through the Client request path', async () => {
|
|
96
|
+
const yielded = { count: 0 };
|
|
97
|
+
const firstYield = deferred();
|
|
98
|
+
const secondYield = deferred();
|
|
99
|
+
const sinkConsumed = { count: 0 };
|
|
100
|
+
const sinkGate = deferred();
|
|
101
|
+
const responseGate = deferred();
|
|
102
|
+
const response = new Uint8Array([7]);
|
|
103
|
+
const client = new Client(async () => ({
|
|
104
|
+
source: (async function* () {
|
|
105
|
+
await responseGate.promise;
|
|
106
|
+
yield Packet.toBinary({
|
|
107
|
+
body: {
|
|
108
|
+
case: 'callData',
|
|
109
|
+
value: {
|
|
110
|
+
data: response,
|
|
111
|
+
dataIsZero: false,
|
|
112
|
+
complete: false,
|
|
113
|
+
error: '',
|
|
114
|
+
},
|
|
115
|
+
},
|
|
116
|
+
});
|
|
117
|
+
})(),
|
|
118
|
+
sink: async (source) => {
|
|
119
|
+
await sinkGate.promise;
|
|
120
|
+
for await (const _packet of source) {
|
|
121
|
+
sinkConsumed.count++;
|
|
122
|
+
}
|
|
123
|
+
},
|
|
124
|
+
}));
|
|
125
|
+
const request = client.clientStreamingRequest('test.Service', 'Upload', chunkSource(yielded, new Map([
|
|
126
|
+
[0, firstYield],
|
|
127
|
+
[1, secondYield],
|
|
128
|
+
])));
|
|
129
|
+
await promiseWithTimeout(firstYield.promise, 'first upload chunk');
|
|
130
|
+
expect(yielded.count).toBe(1);
|
|
131
|
+
await expectPending(secondYield.promise, 'second upload chunk');
|
|
132
|
+
expect(yielded.count).toBe(1);
|
|
133
|
+
expect(sinkConsumed.count).toBe(0);
|
|
134
|
+
await expectPending(request, 'client-stream request');
|
|
135
|
+
sinkGate.resolve();
|
|
136
|
+
responseGate.resolve();
|
|
137
|
+
await expect(promiseWithTimeout(request, 'client-stream response')).resolves.toEqual(response);
|
|
138
|
+
});
|
|
139
|
+
});
|
|
140
|
+
async function* chunkSource(yielded, yieldMarks) {
|
|
141
|
+
for (const i of Array.from({ length: 32 }, (_, index) => index)) {
|
|
142
|
+
yielded.count++;
|
|
143
|
+
yieldMarks?.get(i)?.resolve();
|
|
144
|
+
yield new Uint8Array([i]);
|
|
145
|
+
}
|
|
146
|
+
}
|
|
147
|
+
function promiseWithTimeout(promise, label) {
|
|
148
|
+
return Promise.race([
|
|
149
|
+
promise,
|
|
150
|
+
new Promise((_, reject) => {
|
|
151
|
+
setTimeout(() => reject(new Error(`timed out waiting for ${label}`)), 500);
|
|
152
|
+
}),
|
|
153
|
+
]);
|
|
154
|
+
}
|
|
155
|
+
async function expectPending(promise, label) {
|
|
156
|
+
await expect(Promise.race([
|
|
157
|
+
promise.then(() => 'settled'),
|
|
158
|
+
new Promise((resolve) => {
|
|
159
|
+
setTimeout(() => resolve('pending'), 50);
|
|
160
|
+
}),
|
|
161
|
+
]), `${label} should not be requested while outbound packets are blocked`).resolves.toBe('pending');
|
|
162
|
+
}
|
|
163
|
+
function deferred() {
|
|
164
|
+
const callbacks = {};
|
|
165
|
+
const promise = new Promise((resolve, reject) => {
|
|
166
|
+
callbacks.resolve = resolve;
|
|
167
|
+
callbacks.reject = reject;
|
|
168
|
+
});
|
|
169
|
+
return {
|
|
170
|
+
promise,
|
|
171
|
+
resolve() {
|
|
172
|
+
callbacks.resolve?.();
|
|
173
|
+
},
|
|
174
|
+
reject(reason) {
|
|
175
|
+
callbacks.reject?.(reason);
|
|
176
|
+
},
|
|
177
|
+
};
|
|
178
|
+
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "starpc",
|
|
3
|
-
"version": "0.49.
|
|
3
|
+
"version": "0.49.17",
|
|
4
4
|
"description": "Streaming protobuf RPC service protocol over any two-way channel.",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": {
|
|
@@ -128,9 +128,9 @@
|
|
|
128
128
|
"vitest": "^4.1.5"
|
|
129
129
|
},
|
|
130
130
|
"dependencies": {
|
|
131
|
-
"@aptre/yamux": "^1.0.3",
|
|
132
131
|
"@aptre/it-ws": "^1.1.2",
|
|
133
|
-
"@aptre/protobuf-es-lite": "
|
|
132
|
+
"@aptre/protobuf-es-lite": "1.1.0",
|
|
133
|
+
"@aptre/yamux": "^1.0.3",
|
|
134
134
|
"event-iterator": "^2.0.0",
|
|
135
135
|
"isomorphic-ws": "^5.0.0",
|
|
136
136
|
"it-first": "^3.0.11",
|
|
@@ -139,5 +139,11 @@
|
|
|
139
139
|
"it-stream-types": "^2.0.4",
|
|
140
140
|
"uint8arraylist": "^3.0.0",
|
|
141
141
|
"ws": "^8.20.0"
|
|
142
|
+
},
|
|
143
|
+
"resolutions": {
|
|
144
|
+
"@aptre/protobuf-es-lite": "1.1.0"
|
|
145
|
+
},
|
|
146
|
+
"overrides": {
|
|
147
|
+
"@aptre/protobuf-es-lite": "1.1.0"
|
|
142
148
|
}
|
|
143
149
|
}
|
package/srpc/client-rpc.go
CHANGED
|
@@ -31,7 +31,7 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
|
|
|
31
31
|
}
|
|
32
32
|
|
|
33
33
|
if err := r.ctx.Err(); err != nil {
|
|
34
|
-
r.
|
|
34
|
+
r.cancelContext()
|
|
35
35
|
_ = writer.Close()
|
|
36
36
|
return context.Canceled
|
|
37
37
|
}
|
|
@@ -48,7 +48,7 @@ func (r *ClientRPC) Start(writer PacketWriter, writeFirstMsg bool, firstMsg []by
|
|
|
48
48
|
pkt := NewCallStartPacket(r.service, r.method, firstMsg, firstMsgEmpty)
|
|
49
49
|
err = writer.WritePacket(pkt)
|
|
50
50
|
if err != nil {
|
|
51
|
-
r.
|
|
51
|
+
r.cancelContext()
|
|
52
52
|
_ = writer.Close()
|
|
53
53
|
}
|
|
54
54
|
|
|
@@ -74,7 +74,7 @@ func (r *ClientRPC) HandleStreamClose(closeErr error) {
|
|
|
74
74
|
r.remoteErr = closeErr
|
|
75
75
|
}
|
|
76
76
|
r.dataClosed = true
|
|
77
|
-
r.
|
|
77
|
+
r.cancelContext()
|
|
78
78
|
locked.Broadcast()
|
|
79
79
|
locked.Unlock()
|
|
80
80
|
}
|
package/srpc/client.go
CHANGED
package/srpc/common-rpc.go
CHANGED
|
@@ -5,16 +5,19 @@ import (
|
|
|
5
5
|
"io"
|
|
6
6
|
"sync/atomic"
|
|
7
7
|
|
|
8
|
+
"github.com/aperturerobotics/starpc/internal/contextutil"
|
|
8
9
|
"github.com/aperturerobotics/util/broadcast"
|
|
9
10
|
"github.com/pkg/errors"
|
|
10
11
|
)
|
|
11
12
|
|
|
12
13
|
// commonRPC contains common logic between server/client rpc.
|
|
13
14
|
type commonRPC struct {
|
|
14
|
-
// ctx is the context, canceled when the
|
|
15
|
+
// ctx is the RPC context, canceled when the RPC is canceled.
|
|
15
16
|
ctx context.Context
|
|
16
|
-
// ctxCancel
|
|
17
|
+
// ctxCancel cancels ctx.
|
|
17
18
|
ctxCancel context.CancelFunc
|
|
19
|
+
// ctxCanceled tracks whether ctxCancel has already been called.
|
|
20
|
+
ctxCanceled atomic.Bool
|
|
18
21
|
// service is the rpc service
|
|
19
22
|
service string
|
|
20
23
|
// method is the rpc method
|
|
@@ -28,6 +31,11 @@ type commonRPC struct {
|
|
|
28
31
|
writer PacketWriter
|
|
29
32
|
// writerClosed is set after writer has been closed locally.
|
|
30
33
|
writerClosed bool
|
|
34
|
+
// localCompleting is set while the local handler is publishing its terminal
|
|
35
|
+
// packet and closing the writer.
|
|
36
|
+
localCompleting bool
|
|
37
|
+
// localDone is set after the local handler has completed normally.
|
|
38
|
+
localDone bool
|
|
31
39
|
// dataQueue contains incoming data packets.
|
|
32
40
|
// note: packets may be len() == 0
|
|
33
41
|
dataQueue [][]byte
|
|
@@ -40,7 +48,14 @@ type commonRPC struct {
|
|
|
40
48
|
|
|
41
49
|
// initCommonRPC initializes the commonRPC.
|
|
42
50
|
func initCommonRPC(ctx context.Context, rpc *commonRPC) {
|
|
43
|
-
rpc.ctx, rpc.ctxCancel =
|
|
51
|
+
rpc.ctx, rpc.ctxCancel = contextutil.WithCancel(ctx)
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
func (c *commonRPC) cancelContext() {
|
|
55
|
+
if c.ctxCanceled.Swap(true) {
|
|
56
|
+
return
|
|
57
|
+
}
|
|
58
|
+
c.ctxCancel()
|
|
44
59
|
}
|
|
45
60
|
|
|
46
61
|
// Context is canceled when the rpc has finished.
|
|
@@ -53,10 +68,13 @@ func (c *commonRPC) Wait(ctx context.Context) error {
|
|
|
53
68
|
for {
|
|
54
69
|
var err error
|
|
55
70
|
var waitCh <-chan struct{}
|
|
56
|
-
var
|
|
71
|
+
var rpcCanceled bool
|
|
72
|
+
var localDone bool
|
|
57
73
|
locked := c.bcast.Lock()
|
|
58
|
-
|
|
59
|
-
|
|
74
|
+
err = c.remoteErr
|
|
75
|
+
rpcCanceled = c.ctx.Err() != nil
|
|
76
|
+
localDone = c.localDone
|
|
77
|
+
if err == nil && !rpcCanceled && !localDone {
|
|
60
78
|
waitCh = locked.WaitCh()
|
|
61
79
|
}
|
|
62
80
|
locked.Unlock()
|
|
@@ -64,15 +82,16 @@ func (c *commonRPC) Wait(ctx context.Context) error {
|
|
|
64
82
|
if err != nil {
|
|
65
83
|
return err
|
|
66
84
|
}
|
|
67
|
-
if
|
|
68
|
-
|
|
85
|
+
if localDone {
|
|
86
|
+
return nil
|
|
87
|
+
}
|
|
88
|
+
if rpcCanceled {
|
|
69
89
|
return context.Canceled
|
|
70
90
|
}
|
|
71
91
|
|
|
72
92
|
select {
|
|
73
93
|
case <-ctx.Done():
|
|
74
94
|
return context.Canceled
|
|
75
|
-
case <-rpcCtx.Done():
|
|
76
95
|
case <-waitCh:
|
|
77
96
|
}
|
|
78
97
|
}
|
|
@@ -153,12 +172,15 @@ func (c *commonRPC) HandleStreamClose(closeErr error) {
|
|
|
153
172
|
locked.Unlock()
|
|
154
173
|
return
|
|
155
174
|
}
|
|
175
|
+
normalRemoteCloseAfterLocalComplete := closeErr == nil && (c.localCompleting || c.localDone)
|
|
156
176
|
if closeErr != nil && c.remoteErr == nil {
|
|
157
177
|
c.remoteErr = closeErr
|
|
158
178
|
}
|
|
159
179
|
c.dataClosed = true
|
|
160
|
-
|
|
161
|
-
|
|
180
|
+
if !normalRemoteCloseAfterLocalComplete {
|
|
181
|
+
c.cancelContext()
|
|
182
|
+
writer = c.closeWriterLocked()
|
|
183
|
+
}
|
|
162
184
|
locked.Broadcast()
|
|
163
185
|
locked.Unlock()
|
|
164
186
|
if writer != nil {
|
|
@@ -225,7 +247,7 @@ func (c *commonRPC) closeLocked(locked *broadcast.Locked) PacketWriter {
|
|
|
225
247
|
}
|
|
226
248
|
writer := c.closeWriterLocked()
|
|
227
249
|
locked.Broadcast()
|
|
228
|
-
c.
|
|
250
|
+
c.cancelContext()
|
|
229
251
|
return writer
|
|
230
252
|
}
|
|
231
253
|
|
|
@@ -236,3 +258,25 @@ func (c *commonRPC) closeWriterLocked() PacketWriter {
|
|
|
236
258
|
c.writerClosed = true
|
|
237
259
|
return c.writer
|
|
238
260
|
}
|
|
261
|
+
|
|
262
|
+
func (c *commonRPC) beginLocalCompletion() {
|
|
263
|
+
locked := c.bcast.Lock()
|
|
264
|
+
c.localCompleted.Store(true)
|
|
265
|
+
c.localCompleting = true
|
|
266
|
+
locked.Unlock()
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
func (c *commonRPC) finishLocalCompletion() {
|
|
270
|
+
locked := c.bcast.Lock()
|
|
271
|
+
c.localCompleted.Store(true)
|
|
272
|
+
writer := c.closeWriterLocked()
|
|
273
|
+
locked.Unlock()
|
|
274
|
+
if writer != nil {
|
|
275
|
+
_ = writer.Close()
|
|
276
|
+
}
|
|
277
|
+
locked = c.bcast.Lock()
|
|
278
|
+
c.localCompleting = false
|
|
279
|
+
c.localDone = true
|
|
280
|
+
locked.Broadcast()
|
|
281
|
+
locked.Unlock()
|
|
282
|
+
}
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import { describe, expect, it } from 'vitest'
|
|
2
|
+
|
|
3
|
+
import { Client } from './client.js'
|
|
4
|
+
import { CommonRPC } from './common-rpc.js'
|
|
5
|
+
import { Packet } from './rpcproto.pb.js'
|
|
6
|
+
|
|
7
|
+
describe('CommonRPC', () => {
|
|
8
|
+
it('backpressures call-data sources until outbound packets drain', async () => {
|
|
9
|
+
const rpc = new CommonRPC()
|
|
10
|
+
const yielded = { count: 0 }
|
|
11
|
+
const secondYield = deferred()
|
|
12
|
+
const thirdYield = deferred()
|
|
13
|
+
|
|
14
|
+
const writeDone = rpc.writeCallDataFromSource(
|
|
15
|
+
chunkSource(
|
|
16
|
+
yielded,
|
|
17
|
+
new Map([
|
|
18
|
+
[1, secondYield],
|
|
19
|
+
[2, thirdYield],
|
|
20
|
+
]),
|
|
21
|
+
),
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
|
|
25
|
+
expect(yielded.count).toBe(2)
|
|
26
|
+
await expectPending(thirdYield.promise, 'third call-data chunk')
|
|
27
|
+
expect(yielded.count).toBe(2)
|
|
28
|
+
|
|
29
|
+
const received = new Array<number>()
|
|
30
|
+
const readComplete = (async () => {
|
|
31
|
+
for await (const packet of rpc.source) {
|
|
32
|
+
const body = packet.body
|
|
33
|
+
if (body?.case !== 'callData') {
|
|
34
|
+
continue
|
|
35
|
+
}
|
|
36
|
+
if (body.value.complete) {
|
|
37
|
+
return
|
|
38
|
+
}
|
|
39
|
+
if (!body.value.data) {
|
|
40
|
+
throw new Error('call data packet missing data')
|
|
41
|
+
}
|
|
42
|
+
received.push(body.value.data[0])
|
|
43
|
+
}
|
|
44
|
+
throw new Error('outbound call data ended before completion')
|
|
45
|
+
})()
|
|
46
|
+
|
|
47
|
+
await promiseWithTimeout(
|
|
48
|
+
Promise.all([writeDone, readComplete]),
|
|
49
|
+
'backpressured call data drain',
|
|
50
|
+
)
|
|
51
|
+
await rpc.close()
|
|
52
|
+
|
|
53
|
+
expect(yielded.count).toBe(32)
|
|
54
|
+
expect(received).toHaveLength(32)
|
|
55
|
+
expect(received[0]).toBe(0)
|
|
56
|
+
expect(received[31]).toBe(31)
|
|
57
|
+
})
|
|
58
|
+
|
|
59
|
+
it('wakes call-data writers when the rpc closes while waiting for drain', async () => {
|
|
60
|
+
const rpc = new CommonRPC()
|
|
61
|
+
const yielded = { count: 0 }
|
|
62
|
+
const secondYield = deferred()
|
|
63
|
+
const thirdYield = deferred()
|
|
64
|
+
|
|
65
|
+
const writeDone = rpc.writeCallDataFromSource(
|
|
66
|
+
chunkSource(
|
|
67
|
+
yielded,
|
|
68
|
+
new Map([
|
|
69
|
+
[1, secondYield],
|
|
70
|
+
[2, thirdYield],
|
|
71
|
+
]),
|
|
72
|
+
),
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
|
|
76
|
+
expect(yielded.count).toBe(2)
|
|
77
|
+
await expectPending(thirdYield.promise, 'third call-data chunk')
|
|
78
|
+
expect(yielded.count).toBe(2)
|
|
79
|
+
|
|
80
|
+
await rpc.close()
|
|
81
|
+
await expect(
|
|
82
|
+
promiseWithTimeout(writeDone, 'call data writer close'),
|
|
83
|
+
).resolves.toBeUndefined()
|
|
84
|
+
expect(yielded.count).toBeLessThan(32)
|
|
85
|
+
})
|
|
86
|
+
|
|
87
|
+
it('wakes call-data writers when an error closes the rpc while waiting for drain', async () => {
|
|
88
|
+
const rpc = new CommonRPC()
|
|
89
|
+
const yielded = { count: 0 }
|
|
90
|
+
const secondYield = deferred()
|
|
91
|
+
const thirdYield = deferred()
|
|
92
|
+
|
|
93
|
+
const writeDone = rpc.writeCallDataFromSource(
|
|
94
|
+
chunkSource(
|
|
95
|
+
yielded,
|
|
96
|
+
new Map([
|
|
97
|
+
[1, secondYield],
|
|
98
|
+
[2, thirdYield],
|
|
99
|
+
]),
|
|
100
|
+
),
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
|
|
104
|
+
expect(yielded.count).toBe(2)
|
|
105
|
+
await expectPending(thirdYield.promise, 'third call-data chunk')
|
|
106
|
+
expect(yielded.count).toBe(2)
|
|
107
|
+
|
|
108
|
+
await rpc.close(new Error('boom'))
|
|
109
|
+
await expect(
|
|
110
|
+
promiseWithTimeout(writeDone, 'call data writer error close'),
|
|
111
|
+
).resolves.toBeUndefined()
|
|
112
|
+
expect(yielded.count).toBeLessThan(32)
|
|
113
|
+
})
|
|
114
|
+
|
|
115
|
+
it('does not backpressure call-cancel packets behind queued call data', async () => {
|
|
116
|
+
const rpc = new CommonRPC()
|
|
117
|
+
const yielded = { count: 0 }
|
|
118
|
+
const secondYield = deferred()
|
|
119
|
+
const thirdYield = deferred()
|
|
120
|
+
const writeDone = rpc.writeCallDataFromSource(
|
|
121
|
+
chunkSource(
|
|
122
|
+
yielded,
|
|
123
|
+
new Map([
|
|
124
|
+
[1, secondYield],
|
|
125
|
+
[2, thirdYield],
|
|
126
|
+
]),
|
|
127
|
+
),
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
await promiseWithTimeout(secondYield.promise, 'second call-data chunk')
|
|
131
|
+
expect(yielded.count).toBe(2)
|
|
132
|
+
await expectPending(thirdYield.promise, 'third call-data chunk')
|
|
133
|
+
expect(yielded.count).toBe(2)
|
|
134
|
+
|
|
135
|
+
const cancelDone = rpc.writeCallCancel()
|
|
136
|
+
await rpc.close(new Error('abort'))
|
|
137
|
+
|
|
138
|
+
await expect(
|
|
139
|
+
promiseWithTimeout(cancelDone, 'call cancel write'),
|
|
140
|
+
).resolves.toBeUndefined()
|
|
141
|
+
await expect(
|
|
142
|
+
promiseWithTimeout(writeDone, 'blocked call data close'),
|
|
143
|
+
).resolves.toBeUndefined()
|
|
144
|
+
})
|
|
145
|
+
|
|
146
|
+
it('backpressures client-stream sources through the Client request path', async () => {
|
|
147
|
+
const yielded = { count: 0 }
|
|
148
|
+
const firstYield = deferred()
|
|
149
|
+
const secondYield = deferred()
|
|
150
|
+
const sinkConsumed = { count: 0 }
|
|
151
|
+
const sinkGate = deferred()
|
|
152
|
+
const responseGate = deferred()
|
|
153
|
+
const response = new Uint8Array([7])
|
|
154
|
+
const client = new Client(async () => ({
|
|
155
|
+
source: (async function* () {
|
|
156
|
+
await responseGate.promise
|
|
157
|
+
yield Packet.toBinary({
|
|
158
|
+
body: {
|
|
159
|
+
case: 'callData',
|
|
160
|
+
value: {
|
|
161
|
+
data: response,
|
|
162
|
+
dataIsZero: false,
|
|
163
|
+
complete: false,
|
|
164
|
+
error: '',
|
|
165
|
+
},
|
|
166
|
+
},
|
|
167
|
+
})
|
|
168
|
+
})(),
|
|
169
|
+
sink: async (source) => {
|
|
170
|
+
await sinkGate.promise
|
|
171
|
+
for await (const _packet of source) {
|
|
172
|
+
sinkConsumed.count++
|
|
173
|
+
}
|
|
174
|
+
},
|
|
175
|
+
}))
|
|
176
|
+
|
|
177
|
+
const request = client.clientStreamingRequest(
|
|
178
|
+
'test.Service',
|
|
179
|
+
'Upload',
|
|
180
|
+
chunkSource(
|
|
181
|
+
yielded,
|
|
182
|
+
new Map([
|
|
183
|
+
[0, firstYield],
|
|
184
|
+
[1, secondYield],
|
|
185
|
+
]),
|
|
186
|
+
),
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
await promiseWithTimeout(firstYield.promise, 'first upload chunk')
|
|
190
|
+
expect(yielded.count).toBe(1)
|
|
191
|
+
await expectPending(secondYield.promise, 'second upload chunk')
|
|
192
|
+
expect(yielded.count).toBe(1)
|
|
193
|
+
expect(sinkConsumed.count).toBe(0)
|
|
194
|
+
await expectPending(request, 'client-stream request')
|
|
195
|
+
|
|
196
|
+
sinkGate.resolve()
|
|
197
|
+
responseGate.resolve()
|
|
198
|
+
await expect(
|
|
199
|
+
promiseWithTimeout(request, 'client-stream response'),
|
|
200
|
+
).resolves.toEqual(response)
|
|
201
|
+
})
|
|
202
|
+
})
|
|
203
|
+
|
|
204
|
+
async function* chunkSource(
|
|
205
|
+
yielded: { count: number },
|
|
206
|
+
yieldMarks?: Map<number, Deferred>,
|
|
207
|
+
) {
|
|
208
|
+
for (const i of Array.from({ length: 32 }, (_, index) => index)) {
|
|
209
|
+
yielded.count++
|
|
210
|
+
yieldMarks?.get(i)?.resolve()
|
|
211
|
+
yield new Uint8Array([i])
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
function promiseWithTimeout<T>(promise: Promise<T>, label: string): Promise<T> {
|
|
216
|
+
return Promise.race([
|
|
217
|
+
promise,
|
|
218
|
+
new Promise<T>((_, reject) => {
|
|
219
|
+
setTimeout(() => reject(new Error(`timed out waiting for ${label}`)), 500)
|
|
220
|
+
}),
|
|
221
|
+
])
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
async function expectPending<T>(promise: Promise<T>, label: string) {
|
|
225
|
+
await expect(
|
|
226
|
+
Promise.race([
|
|
227
|
+
promise.then(() => 'settled'),
|
|
228
|
+
new Promise<'pending'>((resolve) => {
|
|
229
|
+
setTimeout(() => resolve('pending'), 50)
|
|
230
|
+
}),
|
|
231
|
+
]),
|
|
232
|
+
`${label} should not be requested while outbound packets are blocked`,
|
|
233
|
+
).resolves.toBe('pending')
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
type Deferred = ReturnType<typeof deferred>
|
|
237
|
+
|
|
238
|
+
function deferred() {
|
|
239
|
+
const callbacks: {
|
|
240
|
+
resolve?: () => void
|
|
241
|
+
reject?: (reason?: unknown) => void
|
|
242
|
+
} = {}
|
|
243
|
+
const promise = new Promise<void>((resolve, reject) => {
|
|
244
|
+
callbacks.resolve = resolve
|
|
245
|
+
callbacks.reject = reject
|
|
246
|
+
})
|
|
247
|
+
return {
|
|
248
|
+
promise,
|
|
249
|
+
resolve() {
|
|
250
|
+
callbacks.resolve?.()
|
|
251
|
+
},
|
|
252
|
+
reject(reason?: unknown) {
|
|
253
|
+
callbacks.reject?.(reason)
|
|
254
|
+
},
|
|
255
|
+
}
|
|
256
|
+
}
|
package/srpc/common-rpc.ts
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import type { Sink, Source } from 'it-stream-types'
|
|
2
|
-
import { pushable } from 'it-pushable'
|
|
2
|
+
import { pushable, type Pushable } from 'it-pushable'
|
|
3
3
|
import { CompleteMessage } from '@aptre/protobuf-es-lite'
|
|
4
4
|
|
|
5
5
|
import type { CallData, CallStart } from './rpcproto.pb.js'
|
|
6
6
|
import { Packet } from './rpcproto.pb.js'
|
|
7
7
|
import { ERR_RPC_ABORT, RemoteRPCError } from './errors.js'
|
|
8
8
|
|
|
9
|
+
const maxBufferedOutgoingPackets = 1
|
|
10
|
+
|
|
9
11
|
// CommonRPC is common logic between server and client RPCs.
|
|
10
12
|
export class CommonRPC {
|
|
11
13
|
// sink is the data sink for incoming messages.
|
|
@@ -16,7 +18,7 @@ export class CommonRPC {
|
|
|
16
18
|
public readonly rpcDataSource: AsyncIterable<Uint8Array>
|
|
17
19
|
|
|
18
20
|
// _source is used to write to the source.
|
|
19
|
-
private readonly _source = pushable<Packet>({
|
|
21
|
+
private readonly _source: Pushable<Packet> = pushable<Packet>({
|
|
20
22
|
objectMode: true,
|
|
21
23
|
})
|
|
22
24
|
|
|
@@ -32,6 +34,8 @@ export class CommonRPC {
|
|
|
32
34
|
|
|
33
35
|
// closed indicates this rpc has been closed already.
|
|
34
36
|
private closed?: true | Error
|
|
37
|
+
// writeDrainAbort wakes writers waiting for outbound stream drain on close.
|
|
38
|
+
private readonly writeDrainAbort = new AbortController()
|
|
35
39
|
|
|
36
40
|
constructor() {
|
|
37
41
|
this.sink = this._createSink()
|
|
@@ -49,6 +53,16 @@ export class CommonRPC {
|
|
|
49
53
|
data?: Uint8Array,
|
|
50
54
|
complete?: boolean,
|
|
51
55
|
error?: string,
|
|
56
|
+
) {
|
|
57
|
+
await this.writeCallDataPacket(data, complete, error)
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
// writeCallDataPacket writes a call-data packet with optional drain control.
|
|
61
|
+
private async writeCallDataPacket(
|
|
62
|
+
data?: Uint8Array,
|
|
63
|
+
complete?: boolean,
|
|
64
|
+
error?: string,
|
|
65
|
+
writeOptions?: WritePacketOptions,
|
|
52
66
|
) {
|
|
53
67
|
const callData: CompleteMessage<CallData> = {
|
|
54
68
|
data: data || new Uint8Array(0),
|
|
@@ -56,22 +70,28 @@ export class CommonRPC {
|
|
|
56
70
|
complete: complete || false,
|
|
57
71
|
error: error || '',
|
|
58
72
|
}
|
|
59
|
-
await this.writePacket(
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
73
|
+
await this.writePacket(
|
|
74
|
+
{
|
|
75
|
+
body: {
|
|
76
|
+
case: 'callData',
|
|
77
|
+
value: callData,
|
|
78
|
+
},
|
|
63
79
|
},
|
|
64
|
-
|
|
80
|
+
writeOptions,
|
|
81
|
+
)
|
|
65
82
|
}
|
|
66
83
|
|
|
67
84
|
// writeCallCancel writes the call cancel packet.
|
|
68
85
|
public async writeCallCancel() {
|
|
69
|
-
await this.writePacket(
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
86
|
+
await this.writePacket(
|
|
87
|
+
{
|
|
88
|
+
body: {
|
|
89
|
+
case: 'callCancel',
|
|
90
|
+
value: true,
|
|
91
|
+
},
|
|
73
92
|
},
|
|
74
|
-
|
|
93
|
+
{ waitForDrain: false },
|
|
94
|
+
)
|
|
75
95
|
}
|
|
76
96
|
|
|
77
97
|
// writeCallDataFromSource writes all call data from the iterable.
|
|
@@ -87,8 +107,25 @@ export class CommonRPC {
|
|
|
87
107
|
}
|
|
88
108
|
|
|
89
109
|
// writePacket writes a packet to the stream.
|
|
90
|
-
protected async writePacket(packet: Packet) {
|
|
110
|
+
protected async writePacket(packet: Packet, options?: WritePacketOptions) {
|
|
111
|
+
if (this.closed && !options?.allowClosed) {
|
|
112
|
+
throw new Error(ERR_RPC_ABORT)
|
|
113
|
+
}
|
|
91
114
|
this._source.push(packet)
|
|
115
|
+
if (
|
|
116
|
+
options?.waitForDrain === false ||
|
|
117
|
+
this._source.readableLength <= maxBufferedOutgoingPackets
|
|
118
|
+
) {
|
|
119
|
+
return
|
|
120
|
+
}
|
|
121
|
+
try {
|
|
122
|
+
await this._source.onEmpty({ signal: this.writeDrainAbort.signal })
|
|
123
|
+
} catch (err) {
|
|
124
|
+
if (this.closed) {
|
|
125
|
+
throw new Error(ERR_RPC_ABORT, { cause: err })
|
|
126
|
+
}
|
|
127
|
+
throw err
|
|
128
|
+
}
|
|
92
129
|
}
|
|
93
130
|
|
|
94
131
|
// handleMessage handles an incoming encoded Packet.
|
|
@@ -179,8 +216,12 @@ export class CommonRPC {
|
|
|
179
216
|
this.closed = err ?? true
|
|
180
217
|
// note: this does nothing if _source is already ended.
|
|
181
218
|
if (err && err.message) {
|
|
182
|
-
await this.
|
|
219
|
+
await this.writeCallDataPacket(undefined, true, err.message, {
|
|
220
|
+
allowClosed: true,
|
|
221
|
+
waitForDrain: false,
|
|
222
|
+
})
|
|
183
223
|
}
|
|
224
|
+
this.writeDrainAbort.abort()
|
|
184
225
|
this._source.end()
|
|
185
226
|
this._rpcDataSource.end(err)
|
|
186
227
|
}
|
|
@@ -206,3 +247,8 @@ export class CommonRPC {
|
|
|
206
247
|
}
|
|
207
248
|
}
|
|
208
249
|
}
|
|
250
|
+
|
|
251
|
+
interface WritePacketOptions {
|
|
252
|
+
allowClosed?: boolean
|
|
253
|
+
waitForDrain?: boolean
|
|
254
|
+
}
|
package/srpc/common-rpc_test.go
CHANGED
|
@@ -4,8 +4,10 @@ import (
|
|
|
4
4
|
"bytes"
|
|
5
5
|
"context"
|
|
6
6
|
"io"
|
|
7
|
+
"sync"
|
|
7
8
|
"sync/atomic"
|
|
8
9
|
"testing"
|
|
10
|
+
"time"
|
|
9
11
|
)
|
|
10
12
|
|
|
11
13
|
type closeCountingPacketWriter struct {
|
|
@@ -16,6 +18,19 @@ type closeCallbackPacketWriter struct {
|
|
|
16
18
|
closeFn func()
|
|
17
19
|
}
|
|
18
20
|
|
|
21
|
+
type packetRecordingWriter struct {
|
|
22
|
+
packets chan *Packet
|
|
23
|
+
closed chan struct{}
|
|
24
|
+
once sync.Once
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
type blockingPacketWriter struct {
|
|
28
|
+
packets chan *Packet
|
|
29
|
+
releaseWrite chan struct{}
|
|
30
|
+
closed chan struct{}
|
|
31
|
+
once sync.Once
|
|
32
|
+
}
|
|
33
|
+
|
|
19
34
|
func (w *closeCountingPacketWriter) WritePacket(*Packet) error {
|
|
20
35
|
return nil
|
|
21
36
|
}
|
|
@@ -36,6 +51,46 @@ func (w *closeCallbackPacketWriter) Close() error {
|
|
|
36
51
|
return nil
|
|
37
52
|
}
|
|
38
53
|
|
|
54
|
+
func newPacketRecordingWriter() *packetRecordingWriter {
|
|
55
|
+
return &packetRecordingWriter{
|
|
56
|
+
packets: make(chan *Packet, 1),
|
|
57
|
+
closed: make(chan struct{}),
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
func (w *packetRecordingWriter) WritePacket(pkt *Packet) error {
|
|
62
|
+
w.packets <- pkt
|
|
63
|
+
return nil
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
func (w *packetRecordingWriter) Close() error {
|
|
67
|
+
w.once.Do(func() {
|
|
68
|
+
close(w.closed)
|
|
69
|
+
})
|
|
70
|
+
return nil
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
func newBlockingPacketWriter() *blockingPacketWriter {
|
|
74
|
+
return &blockingPacketWriter{
|
|
75
|
+
packets: make(chan *Packet, 1),
|
|
76
|
+
releaseWrite: make(chan struct{}),
|
|
77
|
+
closed: make(chan struct{}),
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
func (w *blockingPacketWriter) WritePacket(pkt *Packet) error {
|
|
82
|
+
w.packets <- pkt
|
|
83
|
+
<-w.releaseWrite
|
|
84
|
+
return nil
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
func (w *blockingPacketWriter) Close() error {
|
|
88
|
+
w.once.Do(func() {
|
|
89
|
+
close(w.closed)
|
|
90
|
+
})
|
|
91
|
+
return nil
|
|
92
|
+
}
|
|
93
|
+
|
|
39
94
|
func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
|
|
40
95
|
writer := &closeCountingPacketWriter{}
|
|
41
96
|
rpc := NewServerRPC(context.Background(), InvokerFunc(nil), writer)
|
|
@@ -48,6 +103,23 @@ func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
|
|
|
48
103
|
}
|
|
49
104
|
}
|
|
50
105
|
|
|
106
|
+
func TestCommonRPCCancelContextIdempotent(t *testing.T) {
|
|
107
|
+
var calls atomic.Int32
|
|
108
|
+
rpc := &commonRPC{
|
|
109
|
+
ctx: context.Background(),
|
|
110
|
+
ctxCancel: func() {
|
|
111
|
+
calls.Add(1)
|
|
112
|
+
},
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
rpc.cancelContext()
|
|
116
|
+
rpc.cancelContext()
|
|
117
|
+
|
|
118
|
+
if got := calls.Load(); got != 1 {
|
|
119
|
+
t.Fatalf("expected context cancel once, got %d", got)
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
51
123
|
func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
|
|
52
124
|
var rpc *ServerRPC
|
|
53
125
|
writerClosedOutsideLock := false
|
|
@@ -69,6 +141,176 @@ func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T
|
|
|
69
141
|
}
|
|
70
142
|
}
|
|
71
143
|
|
|
144
|
+
func TestServerRPCWaitReturnsAfterLocalInvokeCompletion(t *testing.T) {
|
|
145
|
+
writer := newPacketRecordingWriter()
|
|
146
|
+
streamCtxCh := make(chan context.Context, 1)
|
|
147
|
+
invokeErrCh := make(chan string, 1)
|
|
148
|
+
rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
|
|
149
|
+
if serviceID != "service" || methodID != "method" {
|
|
150
|
+
invokeErrCh <- serviceID + "/" + methodID
|
|
151
|
+
}
|
|
152
|
+
streamCtxCh <- strm.Context()
|
|
153
|
+
return true, nil
|
|
154
|
+
}), writer)
|
|
155
|
+
|
|
156
|
+
if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
|
|
157
|
+
t.Fatalf("handle call start: %v", err)
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
|
|
161
|
+
defer waitCancel()
|
|
162
|
+
if err := rpc.Wait(waitCtx); err != nil {
|
|
163
|
+
t.Fatalf("wait: %v", err)
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
select {
|
|
167
|
+
case got := <-invokeErrCh:
|
|
168
|
+
t.Fatalf("unexpected invoke target: %s", got)
|
|
169
|
+
default:
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
var streamCtx context.Context
|
|
173
|
+
select {
|
|
174
|
+
case streamCtx = <-streamCtxCh:
|
|
175
|
+
case <-time.After(time.Second):
|
|
176
|
+
t.Fatal("invoker did not receive stream context")
|
|
177
|
+
}
|
|
178
|
+
if err := streamCtx.Err(); err != nil {
|
|
179
|
+
t.Fatalf("normal invoke completion canceled stream context: %v", err)
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
select {
|
|
183
|
+
case pkt := <-writer.packets:
|
|
184
|
+
callData := pkt.GetCallData()
|
|
185
|
+
if callData == nil || !callData.GetComplete() || callData.GetError() != "" {
|
|
186
|
+
t.Fatalf("expected successful completion packet, got %#v", pkt.GetBody())
|
|
187
|
+
}
|
|
188
|
+
default:
|
|
189
|
+
t.Fatal("expected completion packet")
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
select {
|
|
193
|
+
case <-writer.closed:
|
|
194
|
+
default:
|
|
195
|
+
t.Fatal("expected writer closed")
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
func TestServerRPCRemoteCloseAfterLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
|
|
200
|
+
writer := newPacketRecordingWriter()
|
|
201
|
+
streamCtxCh := make(chan context.Context, 1)
|
|
202
|
+
rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
|
|
203
|
+
streamCtxCh <- strm.Context()
|
|
204
|
+
return true, nil
|
|
205
|
+
}), writer)
|
|
206
|
+
|
|
207
|
+
if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
|
|
208
|
+
t.Fatalf("handle call start: %v", err)
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
var streamCtx context.Context
|
|
212
|
+
select {
|
|
213
|
+
case streamCtx = <-streamCtxCh:
|
|
214
|
+
case <-time.After(time.Second):
|
|
215
|
+
t.Fatal("invoker did not receive stream context")
|
|
216
|
+
}
|
|
217
|
+
select {
|
|
218
|
+
case <-writer.closed:
|
|
219
|
+
case <-time.After(time.Second):
|
|
220
|
+
t.Fatal("writer did not close")
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
rpc.HandleStreamClose(nil)
|
|
224
|
+
|
|
225
|
+
if err := streamCtx.Err(); err != nil {
|
|
226
|
+
t.Fatalf("normal remote close after local completion canceled stream context: %v", err)
|
|
227
|
+
}
|
|
228
|
+
waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
|
|
229
|
+
defer waitCancel()
|
|
230
|
+
if err := rpc.Wait(waitCtx); err != nil {
|
|
231
|
+
t.Fatalf("wait: %v", err)
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
func TestServerRPCRemoteCloseDuringLocalCompletionDoesNotCancelStreamContext(t *testing.T) {
|
|
236
|
+
writer := newBlockingPacketWriter()
|
|
237
|
+
streamCtxCh := make(chan context.Context, 1)
|
|
238
|
+
rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
|
|
239
|
+
streamCtxCh <- strm.Context()
|
|
240
|
+
return true, nil
|
|
241
|
+
}), writer)
|
|
242
|
+
|
|
243
|
+
if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
|
|
244
|
+
t.Fatalf("handle call start: %v", err)
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
var streamCtx context.Context
|
|
248
|
+
select {
|
|
249
|
+
case streamCtx = <-streamCtxCh:
|
|
250
|
+
case <-time.After(time.Second):
|
|
251
|
+
t.Fatal("invoker did not receive stream context")
|
|
252
|
+
}
|
|
253
|
+
select {
|
|
254
|
+
case <-writer.packets:
|
|
255
|
+
case <-time.After(time.Second):
|
|
256
|
+
t.Fatal("terminal packet write did not start")
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
rpc.HandleStreamClose(nil)
|
|
260
|
+
|
|
261
|
+
if err := streamCtx.Err(); err != nil {
|
|
262
|
+
t.Fatalf("normal remote close during local completion canceled stream context: %v", err)
|
|
263
|
+
}
|
|
264
|
+
close(writer.releaseWrite)
|
|
265
|
+
select {
|
|
266
|
+
case <-writer.closed:
|
|
267
|
+
case <-time.After(time.Second):
|
|
268
|
+
t.Fatal("writer did not close")
|
|
269
|
+
}
|
|
270
|
+
waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
|
|
271
|
+
defer waitCancel()
|
|
272
|
+
if err := rpc.Wait(waitCtx); err != nil {
|
|
273
|
+
t.Fatalf("wait: %v", err)
|
|
274
|
+
}
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
func TestServerRPCRemoteErrorDuringLocalCompletionCancelsStreamContext(t *testing.T) {
|
|
278
|
+
writer := newBlockingPacketWriter()
|
|
279
|
+
streamCtxCh := make(chan context.Context, 1)
|
|
280
|
+
rpc := NewServerRPC(context.Background(), InvokerFunc(func(serviceID, methodID string, strm Stream) (bool, error) {
|
|
281
|
+
streamCtxCh <- strm.Context()
|
|
282
|
+
return true, nil
|
|
283
|
+
}), writer)
|
|
284
|
+
|
|
285
|
+
if err := rpc.HandleCallStart(NewCallStartPacket("service", "method", nil, false).GetCallStart()); err != nil {
|
|
286
|
+
t.Fatalf("handle call start: %v", err)
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
var streamCtx context.Context
|
|
290
|
+
select {
|
|
291
|
+
case streamCtx = <-streamCtxCh:
|
|
292
|
+
case <-time.After(time.Second):
|
|
293
|
+
t.Fatal("invoker did not receive stream context")
|
|
294
|
+
}
|
|
295
|
+
select {
|
|
296
|
+
case <-writer.packets:
|
|
297
|
+
case <-time.After(time.Second):
|
|
298
|
+
t.Fatal("terminal packet write did not start")
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
rpc.HandleStreamClose(context.Canceled)
|
|
302
|
+
|
|
303
|
+
if err := streamCtx.Err(); err != context.Canceled {
|
|
304
|
+
t.Fatalf("remote error during local completion context error = %v want %v", err, context.Canceled)
|
|
305
|
+
}
|
|
306
|
+
close(writer.releaseWrite)
|
|
307
|
+
waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
|
|
308
|
+
defer waitCancel()
|
|
309
|
+
if err := rpc.Wait(waitCtx); err != context.Canceled {
|
|
310
|
+
t.Fatalf("wait = %v want %v", err, context.Canceled)
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
|
|
72
314
|
func TestClientRPCCloseAfterRemoteCompleteClosesWriter(t *testing.T) {
|
|
73
315
|
writer := &closeCountingPacketWriter{}
|
|
74
316
|
rpc := NewClientRPC(context.Background(), "service", "method")
|
package/srpc/rwc-conn.go
CHANGED
|
@@ -7,6 +7,8 @@ import (
|
|
|
7
7
|
"os"
|
|
8
8
|
"sync"
|
|
9
9
|
"time"
|
|
10
|
+
|
|
11
|
+
"github.com/aperturerobotics/starpc/internal/contextutil"
|
|
10
12
|
)
|
|
11
13
|
|
|
12
14
|
// connPktSize is the size of the buffers to use for packets for the RwcConn.
|
|
@@ -68,7 +70,7 @@ func NewRwcConn(
|
|
|
68
70
|
laddr, raddr net.Addr,
|
|
69
71
|
bufferPacketN int,
|
|
70
72
|
) *RwcConn {
|
|
71
|
-
ctx, ctxCancel :=
|
|
73
|
+
ctx, ctxCancel := contextutil.WithCancel(ctx)
|
|
72
74
|
if bufferPacketN <= 0 {
|
|
73
75
|
bufferPacketN = 10
|
|
74
76
|
}
|
package/srpc/server-rpc.go
CHANGED
|
@@ -91,13 +91,13 @@ func (r *ServerRPC) HandleCallStart(pkt *CallStart) error {
|
|
|
91
91
|
// invokeRPC invokes the RPC after CallStart is received.
|
|
92
92
|
func (r *ServerRPC) invokeRPC(serviceID, methodID string) {
|
|
93
93
|
// on the server side, the writer is closed by invokeRPC.
|
|
94
|
-
strm := NewMsgStream(r.ctx, r, r.
|
|
94
|
+
strm := NewMsgStream(r.ctx, r, r.cancelContext)
|
|
95
95
|
ok, err := r.invoker.InvokeMethod(serviceID, methodID, strm)
|
|
96
96
|
if err == nil && !ok {
|
|
97
97
|
err = ErrUnimplemented
|
|
98
98
|
}
|
|
99
|
+
r.beginLocalCompletion()
|
|
99
100
|
outPkt := NewCallDataPacket(nil, false, true, err)
|
|
100
101
|
_ = r.writer.WritePacket(outPkt)
|
|
101
|
-
|
|
102
|
-
r.ctxCancel()
|
|
102
|
+
r.finishLocalCompletion()
|
|
103
103
|
}
|
package/srpc/server.go
CHANGED
|
@@ -3,6 +3,8 @@ package srpc
|
|
|
3
3
|
import (
|
|
4
4
|
"context"
|
|
5
5
|
"io"
|
|
6
|
+
|
|
7
|
+
"github.com/aperturerobotics/starpc/internal/contextutil"
|
|
6
8
|
)
|
|
7
9
|
|
|
8
10
|
// Server handles incoming RPC streams with a mux.
|
|
@@ -26,7 +28,7 @@ func (s *Server) GetInvoker() Invoker {
|
|
|
26
28
|
// HandleStream handles an incoming stream and runs the read loop.
|
|
27
29
|
// Uses length-prefixed packets.
|
|
28
30
|
func (s *Server) HandleStream(ctx context.Context, rwc io.ReadWriteCloser) {
|
|
29
|
-
subCtx, subCtxCancel :=
|
|
31
|
+
subCtx, subCtxCancel := contextutil.WithCancel(ctx)
|
|
30
32
|
defer subCtxCancel()
|
|
31
33
|
prw := NewPacketReadWriter(rwc)
|
|
32
34
|
serverRPC := NewServerRPC(subCtx, s.invoker, prw)
|
package/srpc/stream-pipe.go
CHANGED
|
@@ -4,6 +4,8 @@ import (
|
|
|
4
4
|
"context"
|
|
5
5
|
"io"
|
|
6
6
|
"sync"
|
|
7
|
+
|
|
8
|
+
"github.com/aperturerobotics/starpc/internal/contextutil"
|
|
7
9
|
)
|
|
8
10
|
|
|
9
11
|
// pipeStream implements an in-memory stream.
|
|
@@ -22,9 +24,9 @@ type pipeStream struct {
|
|
|
22
24
|
// NewPipeStream constructs a new in-memory stream.
|
|
23
25
|
func NewPipeStream(ctx context.Context) (Stream, Stream) {
|
|
24
26
|
s1 := &pipeStream{dataCh: make(chan []byte, 5)}
|
|
25
|
-
s1.ctx, s1.ctxCancel =
|
|
27
|
+
s1.ctx, s1.ctxCancel = contextutil.WithCancel(ctx)
|
|
26
28
|
s2 := &pipeStream{other: s1, dataCh: make(chan []byte, 5)}
|
|
27
|
-
s2.ctx, s2.ctxCancel =
|
|
29
|
+
s2.ctx, s2.ctxCancel = contextutil.WithCancel(ctx)
|
|
28
30
|
s1.other = s2
|
|
29
31
|
return s1, s2
|
|
30
32
|
}
|