starpc 0.49.8 → 0.49.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -5
- package/dist/srpc/packet.d.ts +4 -4
- package/go.mod +3 -3
- package/go.sum +8 -6
- package/package.json +2 -2
- package/srpc/build.rs +593 -0
- package/srpc/common-rpc.go +11 -1
- package/srpc/common-rpc_test.go +65 -0
- package/srpc/lib.rs +4 -1
- package/srpc/packet-rw.go +57 -10
- package/srpc/packet-rw_test.go +68 -0
package/README.md
CHANGED
|
@@ -151,20 +151,20 @@ Add the dependencies to your `Cargo.toml`:
|
|
|
151
151
|
|
|
152
152
|
```toml
|
|
153
153
|
[dependencies]
|
|
154
|
-
starpc = "0.
|
|
155
|
-
prost = "0.
|
|
154
|
+
starpc = "0.49"
|
|
155
|
+
prost = "0.14"
|
|
156
156
|
tokio = { version = "1", features = ["rt", "macros"] }
|
|
157
157
|
|
|
158
158
|
[build-dependencies]
|
|
159
|
-
starpc
|
|
160
|
-
prost-build = "0.
|
|
159
|
+
starpc = { version = "0.49", features = ["build"] }
|
|
160
|
+
prost-build = "0.14"
|
|
161
161
|
```
|
|
162
162
|
|
|
163
163
|
Create a `build.rs` to generate code from your proto files:
|
|
164
164
|
|
|
165
165
|
```rust
|
|
166
166
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
167
|
-
|
|
167
|
+
starpc::build::configure()
|
|
168
168
|
.compile_protos(&["proto/echo.proto"], &["proto"])?;
|
|
169
169
|
Ok(())
|
|
170
170
|
}
|
package/dist/srpc/packet.d.ts
CHANGED
|
@@ -7,14 +7,14 @@ export declare function uint32LEDecode(data: Uint8ArrayList): number;
|
|
|
7
7
|
export declare namespace uint32LEDecode {
|
|
8
8
|
var bytes: number;
|
|
9
9
|
}
|
|
10
|
-
export declare function uint32LEEncode(value: number): Uint8ArrayList
|
|
10
|
+
export declare function uint32LEEncode(value: number): Uint8ArrayList<ArrayBuffer>;
|
|
11
11
|
export declare namespace uint32LEEncode {
|
|
12
12
|
var bytes: number;
|
|
13
13
|
}
|
|
14
|
-
export declare function lengthPrefixEncode(source: Source<Uint8Array | Uint8ArrayList>, lengthEncoder: typeof uint32LEEncode): AsyncGenerator<Uint8ArrayList
|
|
15
|
-
export declare function lengthPrefixDecode(source: Source<Uint8Array | Uint8ArrayList>, lengthDecoder: typeof uint32LEDecode): AsyncGenerator<Uint8ArrayList
|
|
14
|
+
export declare function lengthPrefixEncode(source: Source<Uint8Array | Uint8ArrayList>, lengthEncoder: typeof uint32LEEncode): AsyncGenerator<Uint8ArrayList<ArrayBufferLike>, void, unknown>;
|
|
15
|
+
export declare function lengthPrefixDecode(source: Source<Uint8Array | Uint8ArrayList>, lengthDecoder: typeof uint32LEDecode): AsyncGenerator<Uint8ArrayList<ArrayBufferLike>, void, unknown>;
|
|
16
16
|
export declare function prependLengthPrefixTransform(lengthEncoder?: {
|
|
17
|
-
(value: number): Uint8ArrayList
|
|
17
|
+
(value: number): Uint8ArrayList<ArrayBuffer>;
|
|
18
18
|
bytes: number;
|
|
19
19
|
}): Transform<Source<Uint8Array | Uint8ArrayList>, AsyncGenerator<Uint8ArrayList, void, undefined> | Generator<Uint8ArrayList, void, undefined>>;
|
|
20
20
|
export declare function parseLengthPrefixTransform(lengthDecoder?: {
|
package/go.mod
CHANGED
|
@@ -5,7 +5,7 @@ go 1.25.0
|
|
|
5
5
|
require (
|
|
6
6
|
github.com/aperturerobotics/common v0.33.0 // latest
|
|
7
7
|
github.com/aperturerobotics/protobuf-go-lite v0.13.0 // latest
|
|
8
|
-
github.com/aperturerobotics/util v1.34.
|
|
8
|
+
github.com/aperturerobotics/util v1.34.5-0.20260516103104-cbfc6d6a0589 // latest
|
|
9
9
|
)
|
|
10
10
|
|
|
11
11
|
require (
|
|
@@ -21,7 +21,7 @@ require (
|
|
|
21
21
|
require (
|
|
22
22
|
github.com/libp2p/go-yamux/v4 v4.0.2 // latest
|
|
23
23
|
github.com/pkg/errors v0.9.1 // latest
|
|
24
|
-
github.com/sirupsen/logrus v1.9.5-0.
|
|
24
|
+
github.com/sirupsen/logrus v1.9.5-0.20260508084601-d4a50659cfd6 // latest
|
|
25
25
|
)
|
|
26
26
|
|
|
27
27
|
require (
|
|
@@ -29,5 +29,5 @@ require (
|
|
|
29
29
|
github.com/tetratelabs/wazero v1.11.0 // indirect
|
|
30
30
|
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
|
|
31
31
|
golang.org/x/mod v0.35.0 // indirect
|
|
32
|
-
golang.org/x/sys v0.
|
|
32
|
+
golang.org/x/sys v0.44.0 // indirect
|
|
33
33
|
)
|
package/go.sum
CHANGED
|
@@ -16,8 +16,10 @@ github.com/aperturerobotics/protobuf v0.0.0-20260203024654-8201686529c4 h1:4Dy3B
|
|
|
16
16
|
github.com/aperturerobotics/protobuf v0.0.0-20260203024654-8201686529c4/go.mod h1:tMgO7y6SJo/d9ZcvrpNqIQtdYT9de+QmYaHOZ4KnhOg=
|
|
17
17
|
github.com/aperturerobotics/protobuf-go-lite v0.13.0 h1:jEvCJhHaJEikDY/va2AUnS0DOb/0n82aISLAqxSh4Sk=
|
|
18
18
|
github.com/aperturerobotics/protobuf-go-lite v0.13.0/go.mod h1:lGH3s5ArCTXKI4wJdlNpaybUtwSjfAG0vdWjxOfMcF8=
|
|
19
|
-
github.com/aperturerobotics/util v1.34.
|
|
20
|
-
github.com/aperturerobotics/util v1.34.
|
|
19
|
+
github.com/aperturerobotics/util v1.34.5-0.20260515183346-68f9eac1d69f h1:xISFLs00h441uZcMVxhZbLIZsMRcjOM5Yont18i7WjA=
|
|
20
|
+
github.com/aperturerobotics/util v1.34.5-0.20260515183346-68f9eac1d69f/go.mod h1:mDe7WnncVuV7yjeeVSsagyfrw4xfncu7d+f0+d70niY=
|
|
21
|
+
github.com/aperturerobotics/util v1.34.5-0.20260516103104-cbfc6d6a0589 h1:8B9O13He1sz8Spr2pc+RL3hBzAMveLgUCXT7BpAfvEY=
|
|
22
|
+
github.com/aperturerobotics/util v1.34.5-0.20260516103104-cbfc6d6a0589/go.mod h1:mDe7WnncVuV7yjeeVSsagyfrw4xfncu7d+f0+d70niY=
|
|
21
23
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
|
22
24
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
23
25
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
|
@@ -30,8 +32,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|
|
30
32
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
31
33
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
|
32
34
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
33
|
-
github.com/sirupsen/logrus v1.9.5-0.
|
|
34
|
-
github.com/sirupsen/logrus v1.9.5-0.
|
|
35
|
+
github.com/sirupsen/logrus v1.9.5-0.20260508084601-d4a50659cfd6 h1:D6qewLO/pJ+JvUwuri1dmCbc32d/eO+xyqg+p3N+9kA=
|
|
36
|
+
github.com/sirupsen/logrus v1.9.5-0.20260508084601-d4a50659cfd6/go.mod h1:FXZFonkDAnFozmO+5hGAFvB0Yg9/j2SIhA/QuIkP180=
|
|
35
37
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
36
38
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
37
39
|
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
|
|
@@ -40,8 +42,8 @@ github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAz
|
|
|
40
42
|
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
|
|
41
43
|
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
|
|
42
44
|
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
|
43
|
-
golang.org/x/sys v0.
|
|
44
|
-
golang.org/x/sys v0.
|
|
45
|
+
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
|
46
|
+
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
|
45
47
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
|
46
48
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
|
47
49
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "starpc",
|
|
3
|
-
"version": "0.49.
|
|
3
|
+
"version": "0.49.10",
|
|
4
4
|
"description": "Streaming protobuf RPC service protocol over any two-way channel.",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": {
|
|
@@ -137,7 +137,7 @@
|
|
|
137
137
|
"it-pipe": "^3.0.1",
|
|
138
138
|
"it-pushable": "^3.2.3",
|
|
139
139
|
"it-stream-types": "^2.0.4",
|
|
140
|
-
"uint8arraylist": "^
|
|
140
|
+
"uint8arraylist": "^3.0.0",
|
|
141
141
|
"ws": "^8.20.0"
|
|
142
142
|
}
|
|
143
143
|
}
|
package/srpc/build.rs
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
1
|
+
//! Build-script helpers for generating Starpc Rust service bindings.
|
|
2
|
+
//!
|
|
3
|
+
//! Enable the `build` feature and call [`configure`] from a downstream
|
|
4
|
+
//! `build.rs` to install the Starpc service generator into `prost-build`.
|
|
5
|
+
|
|
6
|
+
/// Configures `prost-build` to generate Starpc service bindings.
|
|
7
|
+
///
|
|
8
|
+
/// The returned config generates normal Prost message types and appends Starpc
|
|
9
|
+
/// client/server traits, client implementations, stream wrappers, and handler
|
|
10
|
+
/// glue for every protobuf service.
|
|
11
|
+
pub fn configure() -> prost_build::Config {
|
|
12
|
+
let mut config = prost_build::Config::new();
|
|
13
|
+
if let Ok(protoc) = protoc_bin_vendored::protoc_bin_path() {
|
|
14
|
+
config.protoc_executable(protoc);
|
|
15
|
+
}
|
|
16
|
+
if let Ok(include) = protoc_bin_vendored::include_path() {
|
|
17
|
+
config.protoc_arg(format!("--proto_path={}", include.display()));
|
|
18
|
+
}
|
|
19
|
+
config.extern_path(".rpcstream", "::starpc::rpcstream");
|
|
20
|
+
config.service_generator(Box::new(StarpcServiceGenerator::default()));
|
|
21
|
+
config
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
/// StarpcServiceGenerator emits Starpc client and server glue for Prost services.
|
|
25
|
+
#[derive(Debug, Default)]
|
|
26
|
+
pub struct StarpcServiceGenerator;
|
|
27
|
+
|
|
28
|
+
impl prost_build::ServiceGenerator for StarpcServiceGenerator {
|
|
29
|
+
fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
|
|
30
|
+
let mut gen = Generator::new(service, buf);
|
|
31
|
+
gen.generate();
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
struct Generator<'a> {
|
|
36
|
+
service: prost_build::Service,
|
|
37
|
+
buf: &'a mut String,
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
impl<'a> Generator<'a> {
|
|
41
|
+
fn new(service: prost_build::Service, buf: &'a mut String) -> Self {
|
|
42
|
+
Self { service, buf }
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
fn generate(&mut self) {
|
|
46
|
+
let service_id = self.service_id();
|
|
47
|
+
let service_name = self.service.name.clone();
|
|
48
|
+
|
|
49
|
+
self.line("");
|
|
50
|
+
self.line("#[allow(unused_imports)]");
|
|
51
|
+
self.line("use starpc::StreamExt;");
|
|
52
|
+
self.line("");
|
|
53
|
+
self.line(&format!("/// Service ID for {}.", service_name));
|
|
54
|
+
self.line(&format!(
|
|
55
|
+
"pub const {}: &str = {:?};",
|
|
56
|
+
service_id_const(&service_name),
|
|
57
|
+
service_id
|
|
58
|
+
));
|
|
59
|
+
self.line("");
|
|
60
|
+
|
|
61
|
+
self.generate_stream_traits();
|
|
62
|
+
self.generate_client_trait();
|
|
63
|
+
self.generate_client_impl();
|
|
64
|
+
self.generate_stream_impls();
|
|
65
|
+
self.generate_server_trait();
|
|
66
|
+
self.generate_handler();
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
fn generate_stream_traits(&mut self) {
|
|
70
|
+
let service_name = self.service.name.clone();
|
|
71
|
+
for method in self.service.methods.clone() {
|
|
72
|
+
if !method.client_streaming && !method.server_streaming {
|
|
73
|
+
continue;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
let stream_trait = stream_trait_name(&service_name, &method);
|
|
77
|
+
|
|
78
|
+
self.line(&format!(
|
|
79
|
+
"/// Stream trait for {}.{}.",
|
|
80
|
+
service_name, method.proto_name
|
|
81
|
+
));
|
|
82
|
+
self.line("#[starpc::async_trait]");
|
|
83
|
+
self.line(&format!("pub trait {}: Send + Sync {{", stream_trait));
|
|
84
|
+
self.line(" /// Returns the context for this stream.");
|
|
85
|
+
self.line(" fn context(&self) -> &starpc::Context;");
|
|
86
|
+
|
|
87
|
+
if method.client_streaming {
|
|
88
|
+
self.line(" /// Sends a message on the stream.");
|
|
89
|
+
self.line(&format!(
|
|
90
|
+
" async fn send(&self, msg: &{}) -> starpc::Result<()>;",
|
|
91
|
+
method.input_type
|
|
92
|
+
));
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
if method.server_streaming {
|
|
96
|
+
self.line(" /// Receives a message from the stream.");
|
|
97
|
+
self.line(&format!(
|
|
98
|
+
" async fn recv(&self) -> starpc::Result<{}>;",
|
|
99
|
+
method.output_type
|
|
100
|
+
));
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
if method.client_streaming && !method.server_streaming {
|
|
104
|
+
self.line(" /// Closes the send side and receives the response.");
|
|
105
|
+
self.line(&format!(
|
|
106
|
+
" async fn close_and_recv(&self) -> starpc::Result<{}>;",
|
|
107
|
+
method.output_type
|
|
108
|
+
));
|
|
109
|
+
} else {
|
|
110
|
+
self.line(" /// Closes the stream.");
|
|
111
|
+
self.line(" async fn close(&self) -> starpc::Result<()>;");
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
self.line("}");
|
|
115
|
+
self.line("");
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
fn generate_client_trait(&mut self) {
|
|
120
|
+
let service_name = self.service.name.clone();
|
|
121
|
+
self.line(&format!("/// Client trait for {}.", service_name));
|
|
122
|
+
self.line("#[starpc::async_trait]");
|
|
123
|
+
self.line(&format!("pub trait {}Client: Send + Sync {{", service_name));
|
|
124
|
+
|
|
125
|
+
for method in self.service.methods.clone() {
|
|
126
|
+
self.line(&format!(" /// {}.", method.proto_name));
|
|
127
|
+
self.line(&format!(
|
|
128
|
+
" {};",
|
|
129
|
+
client_trait_method(&service_name, &method)
|
|
130
|
+
));
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
self.line("}");
|
|
134
|
+
self.line("");
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
fn generate_client_impl(&mut self) {
|
|
138
|
+
let service_id = self.service_id();
|
|
139
|
+
let service_name = self.service.name.clone();
|
|
140
|
+
|
|
141
|
+
self.line(&format!("/// Client implementation for {}.", service_name));
|
|
142
|
+
self.line(&format!("pub struct {}ClientImpl<C> {{", service_name));
|
|
143
|
+
self.line(" client: C,");
|
|
144
|
+
self.line("}");
|
|
145
|
+
self.line("");
|
|
146
|
+
self.line(&format!(
|
|
147
|
+
"impl<C: starpc::Client> {}ClientImpl<C> {{",
|
|
148
|
+
service_name
|
|
149
|
+
));
|
|
150
|
+
self.line(" /// Creates a new client.");
|
|
151
|
+
self.line(" pub fn new(client: C) -> Self {");
|
|
152
|
+
self.line(" Self { client }");
|
|
153
|
+
self.line(" }");
|
|
154
|
+
self.line("}");
|
|
155
|
+
self.line("");
|
|
156
|
+
self.line("#[starpc::async_trait]");
|
|
157
|
+
self.line(&format!(
|
|
158
|
+
"impl<C: starpc::Client + 'static> {}Client for {}ClientImpl<C> {{",
|
|
159
|
+
service_name, service_name
|
|
160
|
+
));
|
|
161
|
+
|
|
162
|
+
for method in self.service.methods.clone() {
|
|
163
|
+
let method_name = &method.name;
|
|
164
|
+
if method.client_streaming && method.server_streaming {
|
|
165
|
+
self.line(&format!(
|
|
166
|
+
" async fn {}(&self) -> starpc::Result<Box<dyn {}>> {{",
|
|
167
|
+
method_name,
|
|
168
|
+
stream_trait_name(&service_name, &method)
|
|
169
|
+
));
|
|
170
|
+
self.line(&format!(
|
|
171
|
+
" let stream = self.client.new_stream({:?}, {:?}, None).await?;",
|
|
172
|
+
service_id, method.proto_name
|
|
173
|
+
));
|
|
174
|
+
self.line(&format!(
|
|
175
|
+
" Ok(Box::new({}Impl {{ stream }}))",
|
|
176
|
+
stream_trait_name(&service_name, &method)
|
|
177
|
+
));
|
|
178
|
+
self.line(" }");
|
|
179
|
+
} else if method.server_streaming {
|
|
180
|
+
self.line(&format!(
|
|
181
|
+
" async fn {}(&self, request: &{}) -> starpc::Result<Box<dyn {}>> {{",
|
|
182
|
+
method_name,
|
|
183
|
+
method.input_type,
|
|
184
|
+
stream_trait_name(&service_name, &method)
|
|
185
|
+
));
|
|
186
|
+
self.line(" use starpc::ProstMessage;");
|
|
187
|
+
self.line(" let data = request.encode_to_vec();");
|
|
188
|
+
self.line(&format!(
|
|
189
|
+
" let stream = self.client.new_stream({:?}, {:?}, Some(&data)).await?;",
|
|
190
|
+
service_id, method.proto_name
|
|
191
|
+
));
|
|
192
|
+
self.line(" stream.close_send().await?;");
|
|
193
|
+
self.line(&format!(
|
|
194
|
+
" Ok(Box::new({}Impl {{ stream }}))",
|
|
195
|
+
stream_trait_name(&service_name, &method)
|
|
196
|
+
));
|
|
197
|
+
self.line(" }");
|
|
198
|
+
} else if method.client_streaming {
|
|
199
|
+
self.line(&format!(
|
|
200
|
+
" async fn {}(&self) -> starpc::Result<Box<dyn {}>> {{",
|
|
201
|
+
method_name,
|
|
202
|
+
stream_trait_name(&service_name, &method)
|
|
203
|
+
));
|
|
204
|
+
self.line(&format!(
|
|
205
|
+
" let stream = self.client.new_stream({:?}, {:?}, None).await?;",
|
|
206
|
+
service_id, method.proto_name
|
|
207
|
+
));
|
|
208
|
+
self.line(&format!(
|
|
209
|
+
" Ok(Box::new({}Impl {{ stream }}))",
|
|
210
|
+
stream_trait_name(&service_name, &method)
|
|
211
|
+
));
|
|
212
|
+
self.line(" }");
|
|
213
|
+
} else {
|
|
214
|
+
self.line(&format!(
|
|
215
|
+
" async fn {}(&self, request: &{}) -> starpc::Result<{}> {{",
|
|
216
|
+
method_name, method.input_type, method.output_type
|
|
217
|
+
));
|
|
218
|
+
self.line(&format!(
|
|
219
|
+
" self.client.exec_call({:?}, {:?}, request).await",
|
|
220
|
+
service_id, method.proto_name
|
|
221
|
+
));
|
|
222
|
+
self.line(" }");
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
self.line("}");
|
|
227
|
+
self.line("");
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
fn generate_stream_impls(&mut self) {
|
|
231
|
+
let service_name = self.service.name.clone();
|
|
232
|
+
for method in self.service.methods.clone() {
|
|
233
|
+
if !method.client_streaming && !method.server_streaming {
|
|
234
|
+
continue;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
let stream_trait = stream_trait_name(&service_name, &method);
|
|
238
|
+
|
|
239
|
+
self.line(&format!("struct {}Impl {{", stream_trait));
|
|
240
|
+
self.line(" stream: Box<dyn starpc::Stream>,");
|
|
241
|
+
self.line("}");
|
|
242
|
+
self.line("");
|
|
243
|
+
self.line("#[starpc::async_trait]");
|
|
244
|
+
self.line(&format!(
|
|
245
|
+
"impl {} for {}Impl {{",
|
|
246
|
+
stream_trait, stream_trait
|
|
247
|
+
));
|
|
248
|
+
self.line(" fn context(&self) -> &starpc::Context {");
|
|
249
|
+
self.line(" self.stream.context()");
|
|
250
|
+
self.line(" }");
|
|
251
|
+
|
|
252
|
+
if method.client_streaming {
|
|
253
|
+
self.line(&format!(
|
|
254
|
+
" async fn send(&self, msg: &{}) -> starpc::Result<()> {{",
|
|
255
|
+
method.input_type
|
|
256
|
+
));
|
|
257
|
+
self.line(" self.stream.msg_send(msg).await");
|
|
258
|
+
self.line(" }");
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
if method.server_streaming {
|
|
262
|
+
self.line(&format!(
|
|
263
|
+
" async fn recv(&self) -> starpc::Result<{}> {{",
|
|
264
|
+
method.output_type
|
|
265
|
+
));
|
|
266
|
+
self.line(" self.stream.msg_recv().await");
|
|
267
|
+
self.line(" }");
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
if method.client_streaming && !method.server_streaming {
|
|
271
|
+
self.line(&format!(
|
|
272
|
+
" async fn close_and_recv(&self) -> starpc::Result<{}> {{",
|
|
273
|
+
method.output_type
|
|
274
|
+
));
|
|
275
|
+
self.line(" self.stream.close_send().await?;");
|
|
276
|
+
self.line(" self.stream.msg_recv().await");
|
|
277
|
+
self.line(" }");
|
|
278
|
+
} else {
|
|
279
|
+
self.line(" async fn close(&self) -> starpc::Result<()> {");
|
|
280
|
+
self.line(" self.stream.close().await");
|
|
281
|
+
self.line(" }");
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
self.line("}");
|
|
285
|
+
self.line("");
|
|
286
|
+
}
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
fn generate_server_trait(&mut self) {
|
|
290
|
+
let service_name = self.service.name.clone();
|
|
291
|
+
|
|
292
|
+
self.line(&format!("/// Server trait for {}.", service_name));
|
|
293
|
+
self.line("#[starpc::async_trait]");
|
|
294
|
+
self.line(&format!("pub trait {}Server: Send + Sync {{", service_name));
|
|
295
|
+
|
|
296
|
+
for method in self.service.methods.clone() {
|
|
297
|
+
self.line(&format!(" /// {}.", method.proto_name));
|
|
298
|
+
if method.client_streaming && method.server_streaming {
|
|
299
|
+
self.line(&format!(
|
|
300
|
+
" async fn {}(&self, stream: Box<dyn starpc::Stream>) -> starpc::Result<()>;",
|
|
301
|
+
method.name
|
|
302
|
+
));
|
|
303
|
+
} else if method.server_streaming {
|
|
304
|
+
self.line(&format!(
|
|
305
|
+
" async fn {}(&self, request: {}, stream: Box<dyn starpc::Stream>) -> starpc::Result<()>;",
|
|
306
|
+
method.name, method.input_type
|
|
307
|
+
));
|
|
308
|
+
} else if method.client_streaming {
|
|
309
|
+
self.line(&format!(
|
|
310
|
+
" async fn {}(&self, stream: &dyn starpc::Stream) -> starpc::Result<{}>;",
|
|
311
|
+
method.name, method.output_type
|
|
312
|
+
));
|
|
313
|
+
} else {
|
|
314
|
+
self.line(&format!(
|
|
315
|
+
" async fn {}(&self, request: {}) -> starpc::Result<{}>;",
|
|
316
|
+
method.name, method.input_type, method.output_type
|
|
317
|
+
));
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
self.line("}");
|
|
322
|
+
self.line("");
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
fn generate_handler(&mut self) {
|
|
326
|
+
let service_id = self.service_id();
|
|
327
|
+
let service_name = self.service.name.clone();
|
|
328
|
+
let methods = self.service.methods.clone();
|
|
329
|
+
let method_ids = method_ids_const(&service_name);
|
|
330
|
+
|
|
331
|
+
self.line(&format!("const {}: &[&str] = &[", method_ids));
|
|
332
|
+
for method in &methods {
|
|
333
|
+
self.line(&format!(" {:?},", method.proto_name));
|
|
334
|
+
}
|
|
335
|
+
self.line("];");
|
|
336
|
+
self.line("");
|
|
337
|
+
self.line(&format!("/// Handler for {}.", service_name));
|
|
338
|
+
self.line(&format!(
|
|
339
|
+
"pub struct {}Handler<S: {}Server> {{",
|
|
340
|
+
service_name, service_name
|
|
341
|
+
));
|
|
342
|
+
self.line(" server: std::sync::Arc<S>,");
|
|
343
|
+
self.line("}");
|
|
344
|
+
self.line("");
|
|
345
|
+
self.line(&format!(
|
|
346
|
+
"impl<S: {}Server + 'static> {}Handler<S> {{",
|
|
347
|
+
service_name, service_name
|
|
348
|
+
));
|
|
349
|
+
self.line(" /// Creates a new handler wrapping the server implementation.");
|
|
350
|
+
self.line(" pub fn new(server: S) -> Self {");
|
|
351
|
+
self.line(" Self { server: std::sync::Arc::new(server) }");
|
|
352
|
+
self.line(" }");
|
|
353
|
+
self.line("");
|
|
354
|
+
self.line(" /// Creates a new handler with a shared server.");
|
|
355
|
+
self.line(" pub fn with_arc(server: std::sync::Arc<S>) -> Self {");
|
|
356
|
+
self.line(" Self { server }");
|
|
357
|
+
self.line(" }");
|
|
358
|
+
self.line("}");
|
|
359
|
+
self.line("");
|
|
360
|
+
self.line("#[starpc::async_trait]");
|
|
361
|
+
self.line(&format!(
|
|
362
|
+
"impl<S: {}Server + 'static> starpc::Invoker for {}Handler<S> {{",
|
|
363
|
+
service_name, service_name
|
|
364
|
+
));
|
|
365
|
+
self.line(" async fn invoke_method(");
|
|
366
|
+
self.line(" &self,");
|
|
367
|
+
self.line(" _service_id: &str,");
|
|
368
|
+
self.line(" method_id: &str,");
|
|
369
|
+
self.line(" stream: Box<dyn starpc::Stream>,");
|
|
370
|
+
self.line(" ) -> (bool, starpc::Result<()>) {");
|
|
371
|
+
self.line(" match method_id {");
|
|
372
|
+
|
|
373
|
+
for method in &methods {
|
|
374
|
+
self.line(&format!(" {:?} => {{", method.proto_name));
|
|
375
|
+
if method.client_streaming && method.server_streaming {
|
|
376
|
+
self.line(&format!(
|
|
377
|
+
" (true, self.server.{}(stream).await)",
|
|
378
|
+
method.name
|
|
379
|
+
));
|
|
380
|
+
} else if method.server_streaming {
|
|
381
|
+
self.line(&format!(
|
|
382
|
+
" let request: {} = match stream.msg_recv().await {{",
|
|
383
|
+
method.input_type
|
|
384
|
+
));
|
|
385
|
+
self.line(" Ok(r) => r,");
|
|
386
|
+
self.line(" Err(e) => return (true, Err(e)),");
|
|
387
|
+
self.line(" };");
|
|
388
|
+
self.line(&format!(
|
|
389
|
+
" (true, self.server.{}(request, stream).await)",
|
|
390
|
+
method.name
|
|
391
|
+
));
|
|
392
|
+
} else if method.client_streaming {
|
|
393
|
+
self.line(&format!(
|
|
394
|
+
" match self.server.{}(stream.as_ref()).await {{",
|
|
395
|
+
method.name
|
|
396
|
+
));
|
|
397
|
+
self.line(" Ok(response) => {");
|
|
398
|
+
self.line(
|
|
399
|
+
" if let Err(e) = stream.msg_send(&response).await {",
|
|
400
|
+
);
|
|
401
|
+
self.line(" return (true, Err(e));");
|
|
402
|
+
self.line(" }");
|
|
403
|
+
self.line(" (true, Ok(()))");
|
|
404
|
+
self.line(" }");
|
|
405
|
+
self.line(" Err(e) => (true, Err(e)),");
|
|
406
|
+
self.line(" }");
|
|
407
|
+
} else {
|
|
408
|
+
self.line(&format!(
|
|
409
|
+
" let request: {} = match stream.msg_recv().await {{",
|
|
410
|
+
method.input_type
|
|
411
|
+
));
|
|
412
|
+
self.line(" Ok(r) => r,");
|
|
413
|
+
self.line(" Err(e) => return (true, Err(e)),");
|
|
414
|
+
self.line(" };");
|
|
415
|
+
self.line(&format!(
|
|
416
|
+
" match self.server.{}(request).await {{",
|
|
417
|
+
method.name
|
|
418
|
+
));
|
|
419
|
+
self.line(" Ok(response) => {");
|
|
420
|
+
self.line(
|
|
421
|
+
" if let Err(e) = stream.msg_send(&response).await {",
|
|
422
|
+
);
|
|
423
|
+
self.line(" return (true, Err(e));");
|
|
424
|
+
self.line(" }");
|
|
425
|
+
self.line(" (true, Ok(()))");
|
|
426
|
+
self.line(" }");
|
|
427
|
+
self.line(" Err(e) => (true, Err(e)),");
|
|
428
|
+
self.line(" }");
|
|
429
|
+
}
|
|
430
|
+
self.line(" }");
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
self.line(" _ => (false, Err(starpc::Error::Unimplemented)),");
|
|
434
|
+
self.line(" }");
|
|
435
|
+
self.line(" }");
|
|
436
|
+
self.line("}");
|
|
437
|
+
self.line("");
|
|
438
|
+
self.line(&format!(
|
|
439
|
+
"impl<S: {}Server + 'static> starpc::Handler for {}Handler<S> {{",
|
|
440
|
+
service_name, service_name
|
|
441
|
+
));
|
|
442
|
+
self.line(" fn service_id(&self) -> &'static str {");
|
|
443
|
+
self.line(&format!(" {:?}", service_id));
|
|
444
|
+
self.line(" }");
|
|
445
|
+
self.line("");
|
|
446
|
+
self.line(" fn method_ids(&self) -> &'static [&'static str] {");
|
|
447
|
+
self.line(&format!(" {}", method_ids));
|
|
448
|
+
self.line(" }");
|
|
449
|
+
self.line("}");
|
|
450
|
+
self.line("");
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
fn service_id(&self) -> String {
|
|
454
|
+
if self.service.package.is_empty() {
|
|
455
|
+
self.service.proto_name.clone()
|
|
456
|
+
} else {
|
|
457
|
+
format!("{}.{}", self.service.package, self.service.proto_name)
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
fn line(&mut self, line: &str) {
|
|
462
|
+
self.buf.push_str(line);
|
|
463
|
+
self.buf.push('\n');
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
fn client_trait_method(service_name: &str, method: &prost_build::Method) -> String {
|
|
468
|
+
let stream_trait = stream_trait_name(service_name, method);
|
|
469
|
+
if method.client_streaming && method.server_streaming {
|
|
470
|
+
format!(
|
|
471
|
+
"async fn {}(&self) -> starpc::Result<Box<dyn {}>>",
|
|
472
|
+
method.name, stream_trait
|
|
473
|
+
)
|
|
474
|
+
} else if method.server_streaming {
|
|
475
|
+
format!(
|
|
476
|
+
"async fn {}(&self, request: &{}) -> starpc::Result<Box<dyn {}>>",
|
|
477
|
+
method.name, method.input_type, stream_trait
|
|
478
|
+
)
|
|
479
|
+
} else if method.client_streaming {
|
|
480
|
+
format!(
|
|
481
|
+
"async fn {}(&self) -> starpc::Result<Box<dyn {}>>",
|
|
482
|
+
method.name, stream_trait
|
|
483
|
+
)
|
|
484
|
+
} else {
|
|
485
|
+
format!(
|
|
486
|
+
"async fn {}(&self, request: &{}) -> starpc::Result<{}>",
|
|
487
|
+
method.name, method.input_type, method.output_type
|
|
488
|
+
)
|
|
489
|
+
}
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
fn stream_trait_name(service_name: &str, method: &prost_build::Method) -> String {
|
|
493
|
+
format!(
|
|
494
|
+
"{}{}Stream",
|
|
495
|
+
service_name,
|
|
496
|
+
upper_camel_from_snake(&method.name)
|
|
497
|
+
)
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
fn service_id_const(service_name: &str) -> String {
|
|
501
|
+
format!("{}_SERVICE_ID", screaming_snake(service_name))
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
fn method_ids_const(service_name: &str) -> String {
|
|
505
|
+
format!("{}_METHOD_IDS", screaming_snake(service_name))
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
/// Converts an UpperCamelCase name to SCREAMING_SNAKE_CASE.
///
/// Every character is upper-cased, and `_` is inserted before each uppercase
/// character except the first. Consecutive capitals are therefore separated
/// individually (`ABc` becomes `A_BC`).
fn screaming_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + 4);
    let mut chars = name.chars();
    if let Some(first) = chars.next() {
        result.extend(first.to_uppercase());
    }
    for c in chars {
        if c.is_uppercase() {
            result.push('_');
        }
        result.extend(c.to_uppercase());
    }
    result
}
|
|
518
|
+
|
|
519
|
+
/// Converts a snake_case identifier (prost-build's method `name`) into
/// UpperCamelCase for use inside generated type names.
///
/// prost-build escapes method names that collide with Rust keywords, either as
/// raw identifiers (e.g. `r#match`) or with a trailing underscore (`self_`).
/// A `r#` prefix is not valid inside a larger identifier, so it is stripped
/// first. Otherwise a streaming RPC named `Match` would produce the invalid
/// type name `<Service>R#matchStream`.
///
/// Underscores are dropped (including a trailing one), and the character that
/// follows an underscore is upper-cased, as is the first character.
fn upper_camel_from_snake(name: &str) -> String {
    let name = name.strip_prefix("r#").unwrap_or(name);
    let mut out = String::with_capacity(name.len());
    let mut upper = true;
    for ch in name.chars() {
        if ch == '_' {
            upper = true;
            continue;
        }
        if upper {
            out.extend(ch.to_uppercase());
            upper = false;
        } else {
            out.push(ch);
        }
    }
    out
}
|
|
536
|
+
|
|
537
|
+
#[cfg(test)]
mod tests {
    use std::fs;

    /// Fixture covering every RPC shape: unary, server streaming, client
    /// streaming, and bidirectional streaming.
    const FIXTURE_PROTO: &str = r#"syntax = "proto3";
package fixture;

service TestService {
rpc Unary(TestMsg) returns (TestMsg);
rpc ServerStream(TestMsg) returns (stream TestMsg);
rpc ClientStream(stream TestMsg) returns (TestMsg);
rpc Bidi(stream TestMsg) returns (stream TestMsg);
}

message TestMsg {
string body = 1;
}
"#;

    /// Snippets that must all appear in the generated `fixture.rs`.
    const EXPECTED_SNIPPETS: [&str; 8] = [
        "pub const TEST_SERVICE_SERVICE_ID",
        "pub trait TestServiceClient",
        "pub struct TestServiceClientImpl",
        "pub trait TestServiceServer",
        "pub struct TestServiceHandler",
        "TestServiceServerStreamStream",
        "TestServiceClientStreamStream",
        "TestServiceBidiStream",
    ];

    #[test]
    fn configure_generates_starpc_service_glue() {
        // Scratch directory unique to this process and invocation.
        let stamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let root = std::env::temp_dir().join(format!(
            "starpc-codegen-test-{}-{}",
            std::process::id(),
            stamp
        ));
        let proto_dir = root.join("proto");
        let out_dir = root.join("out");

        let _ = fs::remove_dir_all(&root);
        fs::create_dir_all(&proto_dir).unwrap();
        fs::create_dir_all(&out_dir).unwrap();

        let proto = proto_dir.join("test.proto");
        fs::write(&proto, FIXTURE_PROTO).unwrap();

        let mut config = super::configure();
        config.out_dir(&out_dir);
        config.compile_protos(&[proto], &[proto_dir]).unwrap();

        let generated = fs::read_to_string(out_dir.join("fixture.rs")).unwrap();
        for snippet in EXPECTED_SNIPPETS.iter() {
            assert!(
                generated.contains(snippet),
                "generated code is missing {:?}",
                snippet
            );
        }

        let _ = fs::remove_dir_all(&root);
    }
}
|
package/srpc/common-rpc.go
CHANGED
|
@@ -146,15 +146,22 @@ func (c *commonRPC) WriteCallData(data []byte, dataIsZero, complete bool, err er
|
|
|
146
146
|
|
|
147
147
|
// HandleStreamClose handles the incoming stream closing w/ optional error.
|
|
148
148
|
func (c *commonRPC) HandleStreamClose(closeErr error) {
|
|
149
|
+
var writer PacketWriter
|
|
149
150
|
c.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
|
|
151
|
+
if c.dataClosed {
|
|
152
|
+
return
|
|
153
|
+
}
|
|
150
154
|
if closeErr != nil && c.remoteErr == nil {
|
|
151
155
|
c.remoteErr = closeErr
|
|
152
156
|
}
|
|
153
157
|
c.dataClosed = true
|
|
154
158
|
c.ctxCancel()
|
|
155
|
-
|
|
159
|
+
writer = c.writer
|
|
156
160
|
broadcast()
|
|
157
161
|
})
|
|
162
|
+
if writer != nil {
|
|
163
|
+
_ = writer.Close()
|
|
164
|
+
}
|
|
158
165
|
}
|
|
159
166
|
|
|
160
167
|
// HandleCallCancel handles the call cancel packet.
|
|
@@ -210,6 +217,9 @@ func (c *commonRPC) WriteCallCancel() error {
|
|
|
210
217
|
|
|
211
218
|
// closeLocked releases resources held by the RPC.
|
|
212
219
|
func (c *commonRPC) closeLocked(broadcast func()) {
|
|
220
|
+
if c.dataClosed {
|
|
221
|
+
return
|
|
222
|
+
}
|
|
213
223
|
c.dataClosed = true
|
|
214
224
|
c.localCompleted.Store(true)
|
|
215
225
|
if c.remoteErr == nil {
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
package srpc
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"io"
|
|
6
|
+
"sync/atomic"
|
|
7
|
+
"testing"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
type closeCountingPacketWriter struct {
|
|
11
|
+
closed atomic.Int32
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
type closeCallbackPacketWriter struct {
|
|
15
|
+
closeFn func()
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
func (w *closeCountingPacketWriter) WritePacket(*Packet) error {
|
|
19
|
+
return nil
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
func (w *closeCountingPacketWriter) Close() error {
|
|
23
|
+
w.closed.Add(1)
|
|
24
|
+
return nil
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
func (w *closeCallbackPacketWriter) WritePacket(*Packet) error {
|
|
28
|
+
return nil
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
func (w *closeCallbackPacketWriter) Close() error {
|
|
32
|
+
if w.closeFn != nil {
|
|
33
|
+
w.closeFn()
|
|
34
|
+
}
|
|
35
|
+
return nil
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
func TestCommonRPCHandleStreamCloseIdempotent(t *testing.T) {
|
|
39
|
+
writer := &closeCountingPacketWriter{}
|
|
40
|
+
rpc := NewServerRPC(context.Background(), InvokerFunc(nil), writer)
|
|
41
|
+
|
|
42
|
+
rpc.HandleStreamClose(io.EOF)
|
|
43
|
+
rpc.HandleStreamClose(context.Canceled)
|
|
44
|
+
|
|
45
|
+
if got := writer.closed.Load(); got != 1 {
|
|
46
|
+
t.Fatalf("expected writer closed once, got %d", got)
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
func TestCommonRPCHandleStreamCloseClosesWriterOutsideBroadcastLock(t *testing.T) {
|
|
51
|
+
var rpc *ServerRPC
|
|
52
|
+
writerClosedOutsideLock := false
|
|
53
|
+
writer := &closeCallbackPacketWriter{
|
|
54
|
+
closeFn: func() {
|
|
55
|
+
writerClosedOutsideLock = rpc.bcast.TryHoldLock(func(func(), func() <-chan struct{}) {})
|
|
56
|
+
},
|
|
57
|
+
}
|
|
58
|
+
rpc = NewServerRPC(context.Background(), InvokerFunc(nil), writer)
|
|
59
|
+
|
|
60
|
+
rpc.HandleStreamClose(io.EOF)
|
|
61
|
+
|
|
62
|
+
if !writerClosedOutsideLock {
|
|
63
|
+
t.Fatal("expected writer close outside broadcast lock")
|
|
64
|
+
}
|
|
65
|
+
}
|
package/srpc/lib.rs
CHANGED
|
@@ -9,7 +9,7 @@
|
|
|
9
9
|
//! - **Wire-compatible** with the Go and TypeScript implementations
|
|
10
10
|
//! - **Streaming support** for all RPC patterns
|
|
11
11
|
//! - **Transport agnostic** - works with TCP, WebSocket, or any AsyncRead/AsyncWrite
|
|
12
|
-
//! - **Code generation** via
|
|
12
|
+
//! - **Code generation** via the optional `build` feature
|
|
13
13
|
//!
|
|
14
14
|
//! # Quick Start
|
|
15
15
|
//!
|
|
@@ -73,6 +73,9 @@ pub mod stream;
|
|
|
73
73
|
pub mod testing;
|
|
74
74
|
pub mod transport;
|
|
75
75
|
|
|
76
|
+
#[cfg(feature = "build")]
|
|
77
|
+
pub mod build;
|
|
78
|
+
|
|
76
79
|
// Re-exports for convenience.
|
|
77
80
|
pub use client::{BoxClient, Client, OpenStream, SrpcClient};
|
|
78
81
|
pub use codec::{PacketCodec, MAX_MESSAGE_SIZE};
|
package/srpc/packet-rw.go
CHANGED
|
@@ -5,14 +5,36 @@ import (
|
|
|
5
5
|
"context"
|
|
6
6
|
"encoding/binary"
|
|
7
7
|
"io"
|
|
8
|
-
"math"
|
|
9
8
|
"sync"
|
|
10
9
|
|
|
11
10
|
"github.com/pkg/errors"
|
|
12
11
|
)
|
|
13
12
|
|
|
14
|
-
|
|
15
|
-
|
|
13
|
+
// maxMessageSize is the maximum encoded packet size in bytes; WritePacket
// rejects packets larger than this.
const maxMessageSize = 10_000_000

// readBufferSize is the size in bytes of the pooled scratch buffer used while
// reading packets.
const readBufferSize = 2048

// pooledWriteBufferMaxSize is the capacity limit, in bytes, for outbound frame
// buffers that are returned to writeBufferPool. Larger buffers are not pooled.
const pooledWriteBufferMaxSize = 64 << 10

// readBufferPool hands out *[readBufferSize]byte scratch arrays for reads.
var readBufferPool = sync.Pool{
	New: func() any { return new([readBufferSize]byte) },
}

// writeBufferPool hands out *writeBuffer values for getWriteBuffer and
// putWriteBuffer.
var writeBufferPool = sync.Pool{
	New: func() any { return &writeBuffer{} },
}

// writeBuffer wraps a byte slice so the pool stores a pointer rather than a
// bare slice (putting a slice header into an interface would allocate).
type writeBuffer struct {
	data []byte
}
|
|
16
38
|
|
|
17
39
|
// PacketReadWriter reads and writes packets from a io.ReadWriter.
|
|
18
40
|
// Uses a LittleEndian uint32 length prefix.
|
|
@@ -43,13 +65,13 @@ func (r *PacketReadWriter) WritePacket(p *Packet) error {
|
|
|
43
65
|
defer r.writeMtx.Unlock()
|
|
44
66
|
|
|
45
67
|
msgSize := p.SizeVT()
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
if uint64(msgSize) > uint64(math.MaxUint32) {
|
|
49
|
-
return errors.New("message size exceeds maximum uint32 value")
|
|
68
|
+
if msgSize < 0 || msgSize > maxMessageSize {
|
|
69
|
+
return errors.Errorf("message size %v greater than maximum %v", msgSize, maxMessageSize)
|
|
50
70
|
}
|
|
51
71
|
|
|
52
|
-
|
|
72
|
+
writeBuf := getWriteBuffer(4 + msgSize)
|
|
73
|
+
defer putWriteBuffer(writeBuf)
|
|
74
|
+
data := writeBuf.data
|
|
53
75
|
binary.LittleEndian.PutUint32(data, uint32(msgSize)) //nolint:gosec
|
|
54
76
|
|
|
55
77
|
_, err := p.MarshalToSizedBufferVT(data[4:])
|
|
@@ -59,10 +81,13 @@ func (r *PacketReadWriter) WritePacket(p *Packet) error {
|
|
|
59
81
|
|
|
60
82
|
var written, n int
|
|
61
83
|
for written < len(data) {
|
|
62
|
-
n, err = r.rw.Write(data)
|
|
84
|
+
n, err = r.rw.Write(data[written:])
|
|
63
85
|
if err != nil {
|
|
64
86
|
return err
|
|
65
87
|
}
|
|
88
|
+
if n == 0 {
|
|
89
|
+
return io.ErrShortWrite
|
|
90
|
+
}
|
|
66
91
|
written += n
|
|
67
92
|
}
|
|
68
93
|
|
|
@@ -84,7 +109,9 @@ func (r *PacketReadWriter) ReadPump(cb PacketDataHandler, closed CloseHandler) {
|
|
|
84
109
|
// Does not handle closing the stream, use ReadPump instead.
|
|
85
110
|
func (r *PacketReadWriter) ReadToHandler(cb PacketDataHandler) error {
|
|
86
111
|
var currLen uint32
|
|
87
|
-
|
|
112
|
+
bufPtr := readBufferPool.Get().(*[readBufferSize]byte)
|
|
113
|
+
defer readBufferPool.Put(bufPtr)
|
|
114
|
+
buf := bufPtr[:]
|
|
88
115
|
isOpen := true
|
|
89
116
|
|
|
90
117
|
for isOpen {
|
|
@@ -153,5 +180,25 @@ func (r *PacketReadWriter) readLengthPrefix(b []byte) uint32 {
|
|
|
153
180
|
return binary.LittleEndian.Uint32(b)
|
|
154
181
|
}
|
|
155
182
|
|
|
183
|
+
func getWriteBuffer(size int) *writeBuffer {
|
|
184
|
+
if size > pooledWriteBufferMaxSize {
|
|
185
|
+
return &writeBuffer{data: make([]byte, size)}
|
|
186
|
+
}
|
|
187
|
+
buf := writeBufferPool.Get().(*writeBuffer)
|
|
188
|
+
if cap(buf.data) < size {
|
|
189
|
+
buf.data = make([]byte, size)
|
|
190
|
+
}
|
|
191
|
+
buf.data = buf.data[:size]
|
|
192
|
+
return buf
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
func putWriteBuffer(buf *writeBuffer) {
|
|
196
|
+
if cap(buf.data) <= pooledWriteBufferMaxSize {
|
|
197
|
+
clear(buf.data)
|
|
198
|
+
buf.data = buf.data[:0]
|
|
199
|
+
writeBufferPool.Put(buf)
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
|
|
156
203
|
// Compile-time assertion that *PacketReadWriter implements PacketWriter.
var _ PacketWriter = (*PacketReadWriter)(nil)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
package srpc
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"bytes"
|
|
5
|
+
"encoding/binary"
|
|
6
|
+
"io"
|
|
7
|
+
"testing"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
// chunkedReadWriteCloser is an in-memory io.ReadWriteCloser test double.
//
// Write accepts at most maxWrite bytes per call and reports the shortened
// count with a nil error. This deliberately breaks the io.Writer rule that a
// short write must return an error, so that WritePacket's retry loop is
// exercised. Read always reports io.EOF and Close is a no-op.
type chunkedReadWriteCloser struct {
	bytes.Buffer
	maxWrite int
}

// Read always reports end of stream; these tests only write.
func (c *chunkedReadWriteCloser) Read([]byte) (int, error) {
	return 0, io.EOF
}

// Write appends at most c.maxWrite bytes of p to the embedded buffer and
// returns the number of bytes accepted.
func (c *chunkedReadWriteCloser) Write(p []byte) (int, error) {
	accepted := p[:min(len(p), c.maxWrite)]
	return c.Buffer.Write(accepted)
}

// Close does nothing and always succeeds.
func (c *chunkedReadWriteCloser) Close() error {
	return nil
}
|
|
29
|
+
|
|
30
|
+
func TestPacketReadWriterWritePacketHandlesShortWrites(t *testing.T) {
|
|
31
|
+
pkt := NewCallDataPacket([]byte("packet payload"), false, true, nil)
|
|
32
|
+
size := pkt.SizeVT()
|
|
33
|
+
want := make([]byte, 4+size)
|
|
34
|
+
binary.LittleEndian.PutUint32(want, uint32(size)) //nolint:gosec
|
|
35
|
+
if _, err := pkt.MarshalToSizedBufferVT(want[4:]); err != nil {
|
|
36
|
+
t.Fatal(err)
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
rwc := &chunkedReadWriteCloser{maxWrite: 3}
|
|
40
|
+
if err := NewPacketReadWriter(rwc).WritePacket(pkt); err != nil {
|
|
41
|
+
t.Fatal(err)
|
|
42
|
+
}
|
|
43
|
+
if got := rwc.Bytes(); !bytes.Equal(got, want) {
|
|
44
|
+
t.Fatalf("written packet mismatch:\ngot %x\nwant %x", got, want)
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
func TestPacketUnmarshalCopiesByteFields(t *testing.T) {
|
|
49
|
+
want := []byte("stable data")
|
|
50
|
+
srcPkt := NewCallDataPacket(want, false, true, nil)
|
|
51
|
+
data, err := srcPkt.MarshalVT()
|
|
52
|
+
if err != nil {
|
|
53
|
+
t.Fatal(err)
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
var pkt Packet
|
|
57
|
+
if err := pkt.UnmarshalVT(data); err != nil {
|
|
58
|
+
t.Fatal(err)
|
|
59
|
+
}
|
|
60
|
+
for i := range data {
|
|
61
|
+
data[i] = 0xff
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
got := pkt.GetCallData().GetData()
|
|
65
|
+
if !bytes.Equal(got, want) {
|
|
66
|
+
t.Fatalf("unmarshal retained source bytes: got %q want %q", got, want)
|
|
67
|
+
}
|
|
68
|
+
}
|