replit-river 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replit_river-0.1.0/LICENSE +21 -0
- replit_river-0.1.0/PKG-INFO +22 -0
- replit_river-0.1.0/pyproject.toml +75 -0
- replit_river-0.1.0/replit_river/__init__.py +21 -0
- replit_river-0.1.0/replit_river/client.py +453 -0
- replit_river-0.1.0/replit_river/codegen/__init__.py +0 -0
- replit_river-0.1.0/replit_river/codegen/__main__.py +8 -0
- replit_river-0.1.0/replit_river/codegen/client.py +329 -0
- replit_river-0.1.0/replit_river/codegen/run.py +46 -0
- replit_river-0.1.0/replit_river/codegen/schema.py +148 -0
- replit_river-0.1.0/replit_river/codegen/server.py +327 -0
- replit_river-0.1.0/replit_river/error_schema.py +19 -0
- replit_river-0.1.0/replit_river/py.typed +0 -0
- replit_river-0.1.0/replit_river/rpc.py +375 -0
- replit_river-0.1.0/replit_river/server.py +218 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Repl.it
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: replit-river
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Replit river toolkit for Python
|
|
5
|
+
License: LICENSE
|
|
6
|
+
Keywords: rpc,websockets
|
|
7
|
+
Author: Replit
|
|
8
|
+
Author-email: eng@replit.com
|
|
9
|
+
Requires-Python: >=3.11,<4.0
|
|
10
|
+
Classifier: License :: Other/Proprietary License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Requires-Dist: aiochannel (>=1.2.1,<2.0.0)
|
|
14
|
+
Requires-Dist: black (>=23.11.0,<24.0.0)
|
|
15
|
+
Requires-Dist: grpcio (>=1.59.3,<2.0.0)
|
|
16
|
+
Requires-Dist: grpcio-tools (>=1.59.3,<2.0.0)
|
|
17
|
+
Requires-Dist: msgpack (>=1.0.7,<2.0.0)
|
|
18
|
+
Requires-Dist: nanoid (>=2.0.0,<3.0.0)
|
|
19
|
+
Requires-Dist: protobuf (>=4.24.4,<5.0.0)
|
|
20
|
+
Requires-Dist: pydantic (>=2.5.2,<3.0.0)
|
|
21
|
+
Requires-Dist: pydantic-core (>=2.16.3,<3.0.0)
|
|
22
|
+
Requires-Dist: websockets (>=12.0,<13.0)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["poetry-core"]
|
|
3
|
+
build-backend = "poetry.core.masonry.api"
|
|
4
|
+
|
|
5
|
+
[tool.poetry]
|
|
6
|
+
name="replit-river"
|
|
7
|
+
version="0.1.0"
|
|
8
|
+
description="Replit river toolkit for Python"
|
|
9
|
+
authors = ["Replit <eng@replit.com>"]
|
|
10
|
+
license = "LICENSE"
|
|
11
|
+
keywords = ["rpc", "websockets"]
|
|
12
|
+
|
|
13
|
+
[tool.poetry.scripts]
|
|
14
|
+
river-codegen-py = "river.codegen:run"
|
|
15
|
+
|
|
16
|
+
[tool.poetry.dependencies]
|
|
17
|
+
python = "^3.11"
|
|
18
|
+
grpcio = "^1.59.3"
|
|
19
|
+
grpcio-tools = "^1.59.3"
|
|
20
|
+
protobuf = "^4.24.4"
|
|
21
|
+
black = "^23.11.0"
|
|
22
|
+
msgpack = "^1.0.7"
|
|
23
|
+
aiochannel = "^1.2.1"
|
|
24
|
+
nanoid = "^2.0.0"
|
|
25
|
+
pydantic = "^2.5.2"
|
|
26
|
+
websockets = "^12.0"
|
|
27
|
+
pydantic-core = "^2.16.3"
|
|
28
|
+
|
|
29
|
+
[tool.poetry.group.dev.dependencies]
|
|
30
|
+
pytest = "^7.4.0"
|
|
31
|
+
mypy = "^1.4.0"
|
|
32
|
+
black = "^23.3.0"
|
|
33
|
+
pytest-cov = "^4.1.0"
|
|
34
|
+
ruff = "^0.0.278"
|
|
35
|
+
pytest-mock = "^3.11.1"
|
|
36
|
+
pytest-asyncio = "^0.21.1"
|
|
37
|
+
types-protobuf = "^4.24.0.20240311"
|
|
38
|
+
mypy-protobuf = "^3.5.0"
|
|
39
|
+
deptry = "^0.14.0"
|
|
40
|
+
|
|
41
|
+
[tool.ruff]
|
|
42
|
+
select = ["F", "E", "W", "I001"]
|
|
43
|
+
|
|
44
|
+
# Should be kept in sync with mypy.ini in the project root.
|
|
45
|
+
# The VSCode mypy extension can only read /mypy.ini.
|
|
46
|
+
# mypy run inside the chat container, on the other hand, can only see this file.
|
|
47
|
+
[tool.mypy]
|
|
48
|
+
plugins = "pydantic.mypy"
|
|
49
|
+
disallow_untyped_defs = true
|
|
50
|
+
warn_return_any = true
|
|
51
|
+
|
|
52
|
+
[tool.pytest.ini_options]
|
|
53
|
+
asyncio_mode = "auto" # auto-detect async tests/fixtures
|
|
54
|
+
addopts = "--tb=short"
|
|
55
|
+
env = [
|
|
56
|
+
"DD_DOGSTATSD_DISABLE=true",
|
|
57
|
+
"DD_TRACE_ENABLED=false",
|
|
58
|
+
]
|
|
59
|
+
filterwarnings = [
|
|
60
|
+
"ignore::DeprecationWarning", # google SDKs cause this noise
|
|
61
|
+
]
|
|
62
|
+
markers = [
|
|
63
|
+
"e2e: marks tests as end-to-end (deselect with '-m \"not e2e\"')",
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
[[tool.mypy.overrides]]
|
|
67
|
+
module = [
|
|
68
|
+
"google.auth.*",
|
|
69
|
+
"google.oauth2.*",
|
|
70
|
+
"google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.*",
|
|
71
|
+
"grpc.*",
|
|
72
|
+
"grpc_tools.*",
|
|
73
|
+
"nanoid.*",
|
|
74
|
+
]
|
|
75
|
+
ignore_missing_imports = true
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .client import Client
|
|
2
|
+
from .rpc import (
|
|
3
|
+
GenericRpcHandler,
|
|
4
|
+
GrpcContext,
|
|
5
|
+
rpc_method_handler,
|
|
6
|
+
stream_method_handler,
|
|
7
|
+
subscription_method_handler,
|
|
8
|
+
upload_method_handler,
|
|
9
|
+
)
|
|
10
|
+
from .server import Server
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"Client",
|
|
14
|
+
"Server",
|
|
15
|
+
"GrpcContext",
|
|
16
|
+
"GenericRpcHandler",
|
|
17
|
+
"rpc_method_handler",
|
|
18
|
+
"subscription_method_handler",
|
|
19
|
+
"upload_method_handler",
|
|
20
|
+
"stream_method_handler",
|
|
21
|
+
]
|
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from collections.abc import AsyncIterable, AsyncIterator
|
|
4
|
+
from typing import Any, Callable, Dict, Optional, Union
|
|
5
|
+
|
|
6
|
+
import msgpack # type: ignore
|
|
7
|
+
import nanoid # type: ignore
|
|
8
|
+
from aiochannel import Channel
|
|
9
|
+
from pydantic import ValidationError
|
|
10
|
+
from river.error_schema import RiverException
|
|
11
|
+
from websockets import Data
|
|
12
|
+
from websockets.client import WebSocketClientProtocol
|
|
13
|
+
from websockets.exceptions import ConnectionClosed
|
|
14
|
+
|
|
15
|
+
from .rpc import (
|
|
16
|
+
STREAM_CLOSED_BIT,
|
|
17
|
+
STREAM_OPEN_BIT,
|
|
18
|
+
ControlMessageHandshakeRequest,
|
|
19
|
+
ControlMessageHandshakeResponse,
|
|
20
|
+
ErrorType,
|
|
21
|
+
InitType,
|
|
22
|
+
RequestType,
|
|
23
|
+
ResponseType,
|
|
24
|
+
TransportMessage,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Client:
|
|
29
|
+
def __init__(self, websockets: WebSocketClientProtocol) -> None:
    """Wrap a websocket connection and start the background receive loop.

    Args:
        websockets: The client websocket used for all traffic to the server.

    Note:
        ``asyncio.create_task`` is called here, so the client must be
        constructed while an event loop is running.
    """
    self.ws = websockets
    # Strong references to background tasks so they are not garbage
    # collected while still running.
    self._tasks: set[asyncio.Task] = set()
    self._from = nanoid.generate()
    # Maps a stream id to the channel that receives that stream's payloads.
    self._streams: Dict[str, Channel[Dict[str, Any]]] = {}
    self._seq = 0
    self._ack = 0

    task = asyncio.create_task(self._handle_messages())
    self._tasks.add(task)

    def _handle_messages_callback(task: asyncio.Task) -> None:
        self._tasks.discard(task)
        # Task.exception() raises CancelledError on a cancelled task, so
        # cancellation has to be ruled out before asking for the exception.
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logging.error(
                f"Error in river.client._handle_messages: {error}"
            )

    task.add_done_callback(_handle_messages_callback)
|
|
48
|
+
|
|
49
|
+
async def send_close_stream(
    self, service_name: str, procedure_name: str, stream_id: str
) -> None:
    """Tell the server that this client is done writing to a stream.

    Sends a message flagged with ``STREAM_CLOSED_BIT`` whose payload is
    ``{"type": "CLOSE"}``.

    Args:
        service_name: Service the stream belongs to.
        procedure_name: Procedure the stream belongs to.
        stream_id: Identifier of the stream being closed.
    """
    close_message = self.pack_transport_message(
        from_=self._from,
        to="SERVER",
        serviceName=service_name,
        procedureName=procedure_name,
        streamId=stream_id,
        controlFlags=STREAM_CLOSED_BIT,
        payload={
            "type": "CLOSE",
        },
    )
    # Send through send_transport_message so the sequence counter advances
    # exactly as it does for every other outgoing message; otherwise the
    # next message would reuse this message's sequence number.
    await self.send_transport_message(close_message)
|
|
68
|
+
|
|
69
|
+
def to_transport_message(self, message: Data) -> TransportMessage:
    """Decode a msgpack-encoded websocket frame into a ``TransportMessage``.

    Args:
        message: The raw frame received from the websocket.

    Returns:
        The ``TransportMessage`` built from the unpacked fields.
    """
    # timestamp=3 selects msgpack's datetime decoding for Timestamp values.
    fields = msgpack.unpackb(message, timestamp=3)
    transport_message = TransportMessage(**fields)
    return transport_message
|
|
73
|
+
|
|
74
|
+
async def send_transport_message(self, message: TransportMessage) -> None:
    """Serialize a message with msgpack, send it, and bump the sequence.

    Args:
        message: The transport message to put on the wire.
    """
    wire_fields = message.model_dump(by_alias=True)
    encoded = msgpack.packb(wire_fields, datetime=True)
    await self.ws.send(encoded)
    # The counter only advances once the send has completed.
    self._seq = self._seq + 1
|
|
82
|
+
|
|
83
|
+
def pack_transport_message(
    self,
    from_: str,
    to: str,
    serviceName: str,
    procedureName: str,
    streamId: str,
    controlFlags: int,
    payload: Dict[str, Any],
) -> TransportMessage:
    """Build a ``TransportMessage`` stamped with a fresh id and seq/ack.

    The message id is newly generated; ``seq`` and ``ack`` are taken from
    the client's current counters. Nothing is sent by this method.

    Returns:
        The assembled ``TransportMessage``.
    """
    message_id = nanoid.generate()
    fields: Dict[str, Any] = {
        "id": message_id,
        "from_": from_,
        "to": to,
        "serviceName": serviceName,
        "procedureName": procedureName,
        "streamId": streamId,
        "controlFlags": controlFlags,
        "payload": payload,
        "seq": self._seq,
        "ack": self._ack,
    }
    return TransportMessage(**fields)
|
|
105
|
+
|
|
106
|
+
def generate_nanoid(self) -> str:
    """Return a newly generated nanoid as a ``str``."""
    identifier = nanoid.generate()
    return str(identifier)
|
|
108
|
+
|
|
109
|
+
async def _handle_messages(self) -> None:
    """Perform the handshake, then route incoming messages to their streams.

    A ``HANDSHAKE_REQ`` control message is sent first and the next frame is
    parsed as the handshake response; if it cannot be parsed or reports a
    failure, the error is logged and the method returns. Afterwards every
    binary frame is decoded, checked against the expected sequence number,
    and its payload is put on the channel registered for its stream id.
    Text frames, out-of-order messages and messages for unknown streams are
    skipped; a frame that fails to parse ends the loop.
    """
    handshake_request = ControlMessageHandshakeRequest(
        type="HANDSHAKE_REQ",
        protocol_version="v1",
        instance_id="python-client-" + self.generate_nanoid(),
    )
    await self.send_transport_message(
        TransportMessage(
            id=self.generate_nanoid(),
            from_=self._from,
            to="SERVER",
            seq=0,
            ack=0,
            serviceName=None,
            procedureName=None,
            streamId=self.generate_nanoid(),
            controlFlags=0,
            payload=handshake_request.model_dump(),
        )
    )
    first_message = self.to_transport_message(await self.ws.recv())
    try:
        handshake_response = ControlMessageHandshakeResponse(
            **first_message.payload
        )
    except ValidationError:
        logging.error("Failed to parse handshake response")
        # TODO: close the connection here
        return
    if not handshake_response.status["ok"]:
        logging.error(f"Handshake failed: {handshake_response.status['message']}")
        # TODO: close the connection here
        return

    async for message in self.ws:
        if isinstance(message, str):
            # Not something we will try to handle.
            logging.debug(
                "ignored a message beacuse it was a text frame: %r",
                message,
            )
            continue
        try:
            unpacked = msgpack.unpackb(message, timestamp=3)

            msg = TransportMessage(**unpacked)
            if msg.seq != self._ack:
                logging.debug(
                    "Received out of order message: %d, expected %d",
                    msg.seq,
                    self._ack,
                )
                continue
            # Advance the counter that the check above (and every outgoing
            # message) actually reads; assigning to a differently named
            # attribute would leave the expected sequence stuck at 0.
            self._ack = msg.seq + 1
        except ConnectionClosed:
            logging.info("Connection closed")
            break

        except (
            ValidationError,
            ValueError,
            msgpack.UnpackException,
        ):
            logging.exception("failed to parse message")
            return
        previous_output = self._streams.get(msg.streamId, None)
        if not previous_output:
            logging.warning("no stream for %s", msg.streamId)
            continue
        await previous_output.put(msg.payload)
        if msg.controlFlags & STREAM_CLOSED_BIT != 0:
            # The sender flagged this as the last message of the stream:
            # close the channel and forget the stream.
            logging.info("Closing stream %s", msg.streamId)
            previous_output.close()
            del self._streams[msg.streamId]
|
|
183
|
+
|
|
184
|
+
async def send_rpc(
    self,
    service_name: str,
    procedure_name: str,
    request: RequestType,
    request_serializer: Callable[[RequestType], Any],
    response_deserializer: Callable[[Any], ResponseType],
    error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
    """Issue a single request/response RPC against the server.

    The request is sent as one message that both opens and closes a new
    stream. A leading ack item, if present, is skipped before the reply is
    interpreted.

    Args:
        service_name: Target service.
        procedure_name: Target procedure.
        request: The request object.
        request_serializer: Turns ``request`` into a msgpack-able payload.
        response_deserializer: Builds the result from the reply payload.
        error_deserializer: Builds an error object (with ``code`` and
            ``message``) from the reply payload of a failed call.

    Returns:
        The deserialized response.

    Raises:
        RiverException: If the reply is not ok, or its error payload
            cannot be deserialized.
    """
    stream_id = nanoid.generate()
    reply_channel: Channel[Any] = Channel(1)
    self._streams[stream_id] = reply_channel

    outgoing = self.pack_transport_message(
        from_=self._from,
        to="SERVER",
        serviceName=service_name,
        procedureName=procedure_name,
        streamId=stream_id,
        controlFlags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
        payload=request_serializer(request),
    )
    await self.send_transport_message(outgoing)

    try:
        reply = await reply_channel.get()
        if reply.get("ack", None):
            # The first item was only an ack; the real reply follows.
            reply = await reply_channel.get()
        if reply.get("ok", False):
            return response_deserializer(reply["payload"])
        try:
            error = error_deserializer(reply["payload"])
        except Exception as e:
            raise RiverException("error_deserializer", str(e))
        raise RiverException(error.code, error.message)
    except RiverException as e:
        # Protocol-level failures are passed on without extra logging.
        raise e
    except Exception as e:
        logging.exception("Error during RPC communication")
        raise e
|
|
231
|
+
|
|
232
|
+
async def send_upload(
    self,
    service_name: str,
    procedure_name: str,
    init: Optional[InitType],
    request: AsyncIterable[RequestType],
    init_serializer: Optional[Callable[[InitType], Any]],
    request_serializer: Callable[[RequestType], Any],
    response_deserializer: Callable[[Any], ResponseType],
    error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
    """Upload a sequence of requests and return the server's single reply.

    An optional init message is sent first, then every item of ``request``,
    then a close message. The first message sent carries
    ``STREAM_OPEN_BIT``. One ack is expected per message sent (including
    the close message) before the final reply is read.

    Returns:
        The deserialized response.

    Raises:
        RiverException: If an expected ack is missing, the reply is not ok,
            or its error payload cannot be deserialized.
    """
    stream_id = nanoid.generate()
    reply_channel: Channel[Any] = Channel(1024)
    self._streams[stream_id] = reply_channel

    sent_count = 0
    # Flags for the next outgoing message; only the first one opens the
    # stream, every later one is sent with no flags.
    pending_flags = STREAM_OPEN_BIT

    if init and init_serializer:
        init_message = self.pack_transport_message(
            from_=self._from,
            to="SERVER",
            serviceName=service_name,
            procedureName=procedure_name,
            streamId=stream_id,
            controlFlags=pending_flags,
            payload=init_serializer(init),
        )
        await self.send_transport_message(init_message)
        sent_count += 1
        pending_flags = 0

    async for item in request:
        item_message = self.pack_transport_message(
            from_=self._from,
            to="SERVER",
            serviceName=service_name,
            procedureName=procedure_name,
            streamId=stream_id,
            controlFlags=pending_flags,
            payload=request_serializer(item),
        )
        pending_flags = 0
        await self.send_transport_message(item_message)
        sent_count += 1

    await self.send_close_stream(service_name, procedure_name, stream_id)
    # The close message is acknowledged as well.
    sent_count += 1

    try:
        for _ in range(sent_count):
            ack_item = await reply_channel.get()
            if not ack_item.get("ack", None):
                raise RiverException("ack error", "No ack received")
        reply = await reply_channel.get()
        if reply.get("ok", False):
            return response_deserializer(reply["payload"])
        try:
            error = error_deserializer(reply["payload"])
        except Exception as e:
            raise RiverException("error_deserializer", str(e))
        raise RiverException(error.code, error.message)
    except RiverException as e:
        # Protocol-level failures are passed on without extra logging.
        raise e
    except Exception as e:
        logging.exception("Error during upload communication")
        raise e
|
|
307
|
+
|
|
308
|
+
async def send_subscription(
    self,
    service_name: str,
    procedure_name: str,
    request: RequestType,
    request_serializer: Callable[[RequestType], Any],
    response_deserializer: Callable[[Any], ResponseType],
    error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
    """Send a subscription request and yield each item the server sends.

    One message opening a new stream is sent; an ack is expected first.
    After that, each item with a truthy ``ok`` is yielded through
    ``response_deserializer`` and every other item through
    ``error_deserializer``. Iteration ends on a ``CLOSE`` item or when the
    channel is closed.

    Args:
        service_name: Target service.
        procedure_name: Target procedure.
        request: The subscription request object.
        request_serializer: Turns ``request`` into a msgpack-able payload.
        response_deserializer: Builds a result from an ok item's payload.
        error_deserializer: Builds an error from a non-ok item's payload.

    Raises:
        RiverException: If the first item received is not an ack.
    """
    stream_id = nanoid.generate()
    output: Channel[Any] = Channel(1024)
    self._streams[stream_id] = output
    msg = self.pack_transport_message(
        from_=self._from,
        to="SERVER",
        serviceName=service_name,
        procedureName=procedure_name,
        streamId=stream_id,
        controlFlags=STREAM_OPEN_BIT,
        payload=request_serializer(request),
    )
    await self.send_transport_message(msg)

    # Bound before the try block so the handler below can always format it,
    # even when the failure happens before the first item is received;
    # otherwise an UnboundLocalError would replace the real exception.
    item: Any = None

    # Handle potential errors during communication
    try:
        ack_response = await output.get()
        if not ack_response.get("ack", None):
            raise RiverException("ack error", "No ack received")

        async for item in output:
            if "type" in item and item["type"] == "CLOSE":
                break
            if not item.get("ok", False):
                try:
                    yield error_deserializer(item["payload"])
                except Exception:
                    logging.exception(
                        f"Error during subscription error deserialization: {item}"
                    )
                continue
            yield response_deserializer(item["payload"])
    except Exception as e:
        # Log the failure together with the last item seen, then propagate.
        logging.exception(f"Error during subscription communication : {item}")
        raise e
|
|
357
|
+
|
|
358
|
+
async def send_stream(
    self,
    service_name: str,
    procedure_name: str,
    init: Optional[InitType],
    request: AsyncIterable[RequestType],
    init_serializer: Optional[Callable[[InitType], Any]],
    request_serializer: Callable[[RequestType], Any],
    response_deserializer: Callable[[Any], ResponseType],
    error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
    """Open a bidirectional stream and yield each item the server sends.

    The stream is opened either with the init message or, when no init is
    given, with the first item of ``request``. The remaining items are sent
    from a background task, followed by a close message. Each received item
    with a truthy ``ok`` is yielded through ``response_deserializer`` and
    every other item through ``error_deserializer``. Iteration ends on a
    ``CLOSE`` item or when the channel is closed.

    Args:
        service_name: Target service.
        procedure_name: Target procedure.
        init: Optional init object sent as the opening message.
        request: Async source of request objects.
        init_serializer: Turns ``init`` into a msgpack-able payload.
        request_serializer: Turns a request into a msgpack-able payload.
        response_deserializer: Builds a result from an ok item's payload.
        error_deserializer: Builds an error from a non-ok item's payload.

    Raises:
        RiverException: If an expected ack is missing.
    """

    stream_id = nanoid.generate()
    output: Channel[Any] = Channel(1024)
    self._streams[stream_id] = output
    num_sent_messages = 0

    # A single iterator is shared by the opening message and the encoder
    # task, so an item consumed for the opening message is never produced a
    # second time even if ``request`` hands out a fresh iterator per
    # ``aiter`` call.
    request_iter = aiter(request)

    if init and init_serializer:
        num_sent_messages += 1
        msg = self.pack_transport_message(
            from_=self._from,
            to="SERVER",
            serviceName=service_name,
            procedureName=procedure_name,
            streamId=stream_id,
            controlFlags=STREAM_OPEN_BIT,
            payload=init_serializer(init),
        )
        await self.send_transport_message(msg)
    else:
        num_sent_messages += 1
        # Get the very first message to open the stream
        first = await anext(request_iter)
        msg = self.pack_transport_message(
            from_=self._from,
            to="SERVER",
            serviceName=service_name,
            procedureName=procedure_name,
            streamId=stream_id,
            controlFlags=STREAM_OPEN_BIT,
            payload=request_serializer(first),
        )
        await self.send_transport_message(msg)

    # Create the encoder task
    async def _encode_stream() -> None:
        nonlocal num_sent_messages
        async for item in request_iter:
            num_sent_messages += 1
            msg = self.pack_transport_message(
                from_=self._from,
                to="SERVER",
                serviceName=service_name,
                procedureName=procedure_name,
                streamId=stream_id,
                controlFlags=0,
                payload=request_serializer(item),
            )
            await self.send_transport_message(msg)
        num_sent_messages += 1
        await self.send_close_stream(service_name, procedure_name, stream_id)

    task = asyncio.create_task(_encode_stream())
    self._tasks.add(task)
    task.add_done_callback(lambda _: self._tasks.remove(task))

    # NOTE(review): the count is read once here, before the encoder task
    # has had a chance to send anything further, so only the acks for the
    # messages sent so far are awaited at this point - confirm intended.
    for _ in range(num_sent_messages):
        ack_response = await output.get()
        if not ack_response.get("ack", None):
            raise RiverException("ack error", "No ack received")

    # Handle potential errors during communication
    try:
        async for item in output:
            if "type" in item and item["type"] == "CLOSE":
                # Close through the local reference: the receive loop
                # removes the stream's entry from self._streams when the
                # message carries the closed flag, so looking it up by id
                # here could raise KeyError.
                output.close()
                break
            if not item.get("ok", False):
                try:
                    yield error_deserializer(item["payload"])
                except Exception:
                    logging.exception(
                        f"Error during subscription error deserialization: {item}"
                    )
                continue
            yield response_deserializer(item["payload"])
    except Exception as e:
        # Log the failure, then propagate it to the consumer.
        logging.exception("Error during stream communication")
        raise e
|
|
File without changes
|