replit-river 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Repl.it
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,22 @@
1
+ Metadata-Version: 2.1
2
+ Name: replit-river
3
+ Version: 0.1.0
4
+ Summary: Replit river toolkit for Python
5
+ License: LICENSE
6
+ Keywords: rpc,websockets
7
+ Author: Replit
8
+ Author-email: eng@replit.com
9
+ Requires-Python: >=3.11,<4.0
10
+ Classifier: License :: Other/Proprietary License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Requires-Dist: aiochannel (>=1.2.1,<2.0.0)
14
+ Requires-Dist: black (>=23.11.0,<24.0.0)
15
+ Requires-Dist: grpcio (>=1.59.3,<2.0.0)
16
+ Requires-Dist: grpcio-tools (>=1.59.3,<2.0.0)
17
+ Requires-Dist: msgpack (>=1.0.7,<2.0.0)
18
+ Requires-Dist: nanoid (>=2.0.0,<3.0.0)
19
+ Requires-Dist: protobuf (>=4.24.4,<5.0.0)
20
+ Requires-Dist: pydantic (>=2.5.2,<3.0.0)
21
+ Requires-Dist: pydantic-core (>=2.16.3,<3.0.0)
22
+ Requires-Dist: websockets (>=12.0,<13.0)
@@ -0,0 +1,75 @@
1
+ [build-system]
2
+ requires = ["poetry-core"]
3
+ build-backend = "poetry.core.masonry.api"
4
+
5
+ [tool.poetry]
6
+ name="replit-river"
7
+ version="0.1.0"
8
+ description="Replit river toolkit for Python"
9
+ authors = ["Replit <eng@replit.com>"]
10
+ license = "LICENSE"
11
+ keywords = ["rpc", "websockets"]
12
+
13
+ [tool.poetry.scripts]
14
+ river-codegen-py = "river.codegen:run"
15
+
16
+ [tool.poetry.dependencies]
17
+ python = "^3.11"
18
+ grpcio = "^1.59.3"
19
+ grpcio-tools = "^1.59.3"
20
+ protobuf = "^4.24.4"
21
+ black = "^23.11.0"
22
+ msgpack = "^1.0.7"
23
+ aiochannel = "^1.2.1"
24
+ nanoid = "^2.0.0"
25
+ pydantic = "^2.5.2"
26
+ websockets = "^12.0"
27
+ pydantic-core = "^2.16.3"
28
+
29
+ [tool.poetry.group.dev.dependencies]
30
+ pytest = "^7.4.0"
31
+ mypy = "^1.4.0"
32
+ black = "^23.3.0"
33
+ pytest-cov = "^4.1.0"
34
+ ruff = "^0.0.278"
35
+ pytest-mock = "^3.11.1"
36
+ pytest-asyncio = "^0.21.1"
37
+ types-protobuf = "^4.24.0.20240311"
38
+ mypy-protobuf = "^3.5.0"
39
+ deptry = "^0.14.0"
40
+
41
+ [tool.ruff]
42
+ select = ["F", "E", "W", "I001"]
43
+
44
+ # Should be kept in sync with mypy.ini in the project root.
45
+ # The VSCode mypy extension can only read /mypy.ini.
46
+ # Mypy running inside the chat container, however, can only see this file.
47
+ [tool.mypy]
48
+ plugins = "pydantic.mypy"
49
+ disallow_untyped_defs = true
50
+ warn_return_any = true
51
+
52
+ [tool.pytest.ini_options]
53
+ asyncio_mode = "auto" # auto-detect async tests/fixtures
54
+ addopts = "--tb=short"
55
+ env = [
56
+ "DD_DOGSTATSD_DISABLE=true",
57
+ "DD_TRACE_ENABLED=false",
58
+ ]
59
+ filterwarnings = [
60
+ "ignore::DeprecationWarning", # google SDKs cause this noise
61
+ ]
62
+ markers = [
63
+ "e2e: marks tests as end-to-end (deselect with '-m \"not e2e\"')",
64
+ ]
65
+
66
+ [[tool.mypy.overrides]]
67
+ module = [
68
+ "google.auth.*",
69
+ "google.oauth2.*",
70
+ "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.*",
71
+ "grpc.*",
72
+ "grpc_tools.*",
73
+ "nanoid.*",
74
+ ]
75
+ ignore_missing_imports = true
@@ -0,0 +1,21 @@
1
+ from .client import Client
2
+ from .rpc import (
3
+ GenericRpcHandler,
4
+ GrpcContext,
5
+ rpc_method_handler,
6
+ stream_method_handler,
7
+ subscription_method_handler,
8
+ upload_method_handler,
9
+ )
10
+ from .server import Server
11
+
12
+ __all__ = [
13
+ "Client",
14
+ "Server",
15
+ "GrpcContext",
16
+ "GenericRpcHandler",
17
+ "rpc_method_handler",
18
+ "subscription_method_handler",
19
+ "upload_method_handler",
20
+ "stream_method_handler",
21
+ ]
@@ -0,0 +1,453 @@
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterable, AsyncIterator
4
+ from typing import Any, Callable, Dict, Optional, Union
5
+
6
+ import msgpack # type: ignore
7
+ import nanoid # type: ignore
8
+ from aiochannel import Channel
9
+ from pydantic import ValidationError
10
+ from river.error_schema import RiverException
11
+ from websockets import Data
12
+ from websockets.client import WebSocketClientProtocol
13
+ from websockets.exceptions import ConnectionClosed
14
+
15
+ from .rpc import (
16
+ STREAM_CLOSED_BIT,
17
+ STREAM_OPEN_BIT,
18
+ ControlMessageHandshakeRequest,
19
+ ControlMessageHandshakeResponse,
20
+ ErrorType,
21
+ InitType,
22
+ RequestType,
23
+ ResponseType,
24
+ TransportMessage,
25
+ )
26
+
27
+
28
class Client:
    """Client side of a River connection over an already-open websocket.

    Constructing a ``Client`` starts a background task that performs the
    River handshake and then routes every incoming transport message to the
    channel registered for its ``streamId``. The ``send_*`` methods register a
    channel, send the request message(s), and read replies from that channel.

    NOTE(review): ``__init__`` calls ``asyncio.create_task``, so the client
    must be constructed while an event loop is running.
    """

    def __init__(self, websockets: WebSocketClientProtocol) -> None:
        self.ws = websockets
        # Strong references to background tasks so they are not garbage
        # collected mid-flight; entries are removed when a task finishes.
        self._tasks: set[asyncio.Task[None]] = set()
        # Sender id placed in the ``from`` field of every outgoing message.
        self._from = nanoid.generate()
        # streamId -> channel that ``_handle_messages`` feeds payloads into.
        self._streams: Dict[str, Channel[Dict[str, Any]]] = {}
        # Sequence number stamped on the next outgoing message.
        self._seq = 0
        # Sequence number expected on the next incoming message.
        self._ack = 0

        self._track_task(
            asyncio.create_task(self._handle_messages()),
            "river.client._handle_messages",
        )

    def _track_task(self, task: asyncio.Task[None], description: str) -> None:
        """Hold a reference to ``task`` until it finishes and log its failure.

        Args:
            task: The background task to keep alive.
            description: Label used in the error log if the task raises.
        """
        self._tasks.add(task)

        def _on_done(done: asyncio.Task[None]) -> None:
            self._tasks.discard(done)
            # ``Task.exception()`` raises CancelledError on a cancelled task,
            # so cancellation must be checked first.
            if done.cancelled():
                return
            exc = done.exception()
            if exc is not None:
                logging.error(f"Error in {description}: {exc}")

        task.add_done_callback(_on_done)

    async def send_close_stream(
        self, service_name: str, procedure_name: str, stream_id: str
    ) -> None:
        """Send a CLOSE control message for ``stream_id`` to the server.

        Goes through ``send_transport_message`` so the outgoing sequence
        number advances exactly as it does for every other message.
        """
        msg = self.pack_transport_message(
            from_=self._from,
            to="SERVER",
            serviceName=service_name,
            procedureName=procedure_name,
            streamId=stream_id,
            controlFlags=STREAM_CLOSED_BIT,
            payload={
                "type": "CLOSE",
            },
        )
        await self.send_transport_message(msg)

    def to_transport_message(self, message: Data) -> TransportMessage:
        """Decode a msgpack-encoded websocket frame into a ``TransportMessage``."""
        unpacked = msgpack.unpackb(message, timestamp=3)

        return TransportMessage(**unpacked)

    async def send_transport_message(self, message: TransportMessage) -> None:
        """Msgpack-encode ``message`` and send it over the websocket.

        The outgoing sequence counter is advanced before awaiting the send.
        This way, a concurrent sender (e.g. the ``send_stream`` encoder task)
        that packs a message while this send is in flight gets the next
        number rather than a duplicate.
        """
        encoded = msgpack.packb(
            message.model_dump(by_alias=True),
            datetime=True,
        )
        self._seq += 1
        await self.ws.send(encoded)

    def pack_transport_message(
        self,
        from_: str,
        to: str,
        serviceName: str,
        procedureName: str,
        streamId: str,
        controlFlags: int,
        payload: Dict[str, Any],
    ) -> TransportMessage:
        """Build a ``TransportMessage`` stamped with a fresh id and the current
        outgoing ``seq`` and incoming ``ack`` counters."""
        return TransportMessage(
            id=nanoid.generate(),
            from_=from_,
            to=to,
            serviceName=serviceName,
            procedureName=procedureName,
            streamId=streamId,
            controlFlags=controlFlags,
            payload=payload,
            seq=self._seq,
            ack=self._ack,
        )

    def generate_nanoid(self) -> str:
        """Return a new nanoid as a ``str``."""
        return str(nanoid.generate())

    async def _handle_messages(self) -> None:
        """Background task: perform the handshake, then route incoming messages.

        However this returns (handshake failure, parse failure, connection
        closed, or cancellation), every stream still registered is closed.
        Callers blocked on a reply channel are therefore woken (aiochannel
        presumably raises ``ChannelClosed`` from ``get``) instead of hanging.
        """
        try:
            if await self._perform_handshake():
                await self._route_messages()
        finally:
            self._close_all_streams()

    async def _perform_handshake(self) -> bool:
        """Send the handshake request and validate the server's reply.

        Returns:
            True if the server accepted the handshake, False otherwise.
        """
        handshake_request = ControlMessageHandshakeRequest(
            type="HANDSHAKE_REQ",
            protocol_version="v1",
            instance_id="python-client-" + self.generate_nanoid(),
        )
        await self.send_transport_message(
            TransportMessage(
                id=self.generate_nanoid(),
                from_=self._from,
                to="SERVER",
                seq=0,
                ack=0,
                serviceName=None,
                procedureName=None,
                streamId=self.generate_nanoid(),
                controlFlags=0,
                payload=handshake_request.model_dump(),
            )
        )
        first_message = self.to_transport_message(await self.ws.recv())
        try:
            handshake_response = ControlMessageHandshakeResponse(
                **first_message.payload
            )
        except ValidationError:
            logging.error("Failed to parse handshake response")
            # TODO: close the connection here
            return False
        if not handshake_response.status["ok"]:
            logging.error(f"Handshake failed: {handshake_response.status['message']}")
            # TODO: close the connection here
            return False
        return True

    async def _route_messages(self) -> None:
        """Put each incoming message's payload on the channel of its stream.

        Text frames are ignored. A message whose ``seq`` is not the next
        expected one is dropped. A frame that fails to parse stops routing
        altogether. A message carrying ``STREAM_CLOSED_BIT`` closes and
        deregisters its stream after its payload is delivered.
        """
        try:
            async for message in self.ws:
                if isinstance(message, str):
                    # Not something we will try to handle.
                    logging.debug(
                        "ignored a message beacuse it was a text frame: %r",
                        message,
                    )
                    continue
                try:
                    msg = self.to_transport_message(message)
                except (
                    ValidationError,
                    ValueError,
                    msgpack.UnpackException,
                ):
                    logging.exception("failed to parse message")
                    return
                if msg.seq != self._ack:
                    logging.debug(
                        "Received out of order message: %d, expected %d",
                        msg.seq,
                        self._ack,
                    )
                    continue
                # Advance the next expected incoming sequence number.
                self._ack = msg.seq + 1

                output = self._streams.get(msg.streamId)
                if output is None:
                    logging.warning("no stream for %s", msg.streamId)
                    continue
                await output.put(msg.payload)
                if msg.controlFlags & STREAM_CLOSED_BIT != 0:
                    logging.info("Closing stream %s", msg.streamId)
                    output.close()
                    # pop: the caller may already have deregistered it.
                    self._streams.pop(msg.streamId, None)
        except ConnectionClosed:
            # Raised by the ``async for`` iteration itself on abnormal closure.
            logging.info("Connection closed")

    def _close_all_streams(self) -> None:
        """Close and forget every registered stream channel."""
        for output in list(self._streams.values()):
            output.close()
        self._streams.clear()

    @staticmethod
    def _decode_response(
        response: Dict[str, Any],
        response_deserializer: Callable[[Any], ResponseType],
        error_deserializer: Callable[[Any], ErrorType],
    ) -> ResponseType:
        """Turn a final reply payload into a result, raising on server errors.

        Raises:
            RiverException: with the server's error code/message when
                ``response["ok"]`` is falsy. Raises code
                ``"error_deserializer"`` instead if the error payload itself
                cannot be deserialized.
        """
        if not response.get("ok", False):
            try:
                error = error_deserializer(response["payload"])
            except Exception as e:
                raise RiverException("error_deserializer", str(e)) from e
            raise RiverException(error.code, error.message)
        return response_deserializer(response["payload"])

    async def send_rpc(
        self,
        service_name: str,
        procedure_name: str,
        request: RequestType,
        request_serializer: Callable[[RequestType], Any],
        response_deserializer: Callable[[Any], ResponseType],
        error_deserializer: Callable[[Any], ErrorType],
    ) -> ResponseType:
        """Sends a single RPC request to the server.

        Expects the input and output be messages that will be msgpacked.
        The request goes out as one message with both the open and close bits
        set. If the first reply is an ack, it is skipped before the response
        is read.

        Raises:
            RiverException: if the server replies with an error.
        """
        stream_id = nanoid.generate()
        output: Channel[Any] = Channel(1)
        self._streams[stream_id] = output
        try:
            msg = self.pack_transport_message(
                from_=self._from,
                to="SERVER",
                serviceName=service_name,
                procedureName=procedure_name,
                streamId=stream_id,
                controlFlags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
                payload=request_serializer(request),
            )
            await self.send_transport_message(msg)

            # Handle potential errors during communication
            try:
                response = await output.get()
                if response.get("ack", None):
                    response = await output.get()
                return self._decode_response(
                    response, response_deserializer, error_deserializer
                )
            except RiverException:
                raise
            except Exception:
                # Log the error and propagate it to the caller.
                logging.exception("Error during RPC communication")
                raise
        finally:
            # Deregister on every exit path (errors, cancellation) so the
            # stream cannot leak.
            self._streams.pop(stream_id, None)

    async def send_upload(
        self,
        service_name: str,
        procedure_name: str,
        init: Optional[InitType],
        request: AsyncIterable[RequestType],
        init_serializer: Optional[Callable[[InitType], Any]],
        request_serializer: Callable[[RequestType], Any],
        response_deserializer: Callable[[Any], ResponseType],
        error_deserializer: Callable[[Any], ErrorType],
    ) -> ResponseType:
        """Sends an upload request to the server.

        Expects the input and output be messages that will be msgpacked.
        The stream is opened with ``init`` when both it and
        ``init_serializer`` are given, otherwise with the first request item.
        Every request item is then sent, followed by a CLOSE message. One ack
        is expected per message sent, then a single final response.

        Raises:
            RiverException: if an expected ack is missing or the server
                replies with an error.
        """
        stream_id = nanoid.generate()
        output: Channel[Any] = Channel(1024)
        self._streams[stream_id] = output
        try:
            first_message = True
            num_sent_messages = 0
            if init and init_serializer:
                msg = self.pack_transport_message(
                    from_=self._from,
                    to="SERVER",
                    serviceName=service_name,
                    procedureName=procedure_name,
                    streamId=stream_id,
                    controlFlags=STREAM_OPEN_BIT,
                    payload=init_serializer(init),
                )
                await self.send_transport_message(msg)
                num_sent_messages += 1
                first_message = False

            async for item in request:
                control_flags = STREAM_OPEN_BIT if first_message else 0
                first_message = False
                msg = self.pack_transport_message(
                    from_=self._from,
                    to="SERVER",
                    serviceName=service_name,
                    procedureName=procedure_name,
                    streamId=stream_id,
                    controlFlags=control_flags,
                    payload=request_serializer(item),
                )
                await self.send_transport_message(msg)
                num_sent_messages += 1
            await self.send_close_stream(service_name, procedure_name, stream_id)
            num_sent_messages += 1

            # Handle potential errors during communication
            try:
                for _ in range(num_sent_messages):
                    ack_response = await output.get()
                    if not ack_response.get("ack", None):
                        raise RiverException("ack error", "No ack received")
                response = await output.get()
                return self._decode_response(
                    response, response_deserializer, error_deserializer
                )
            except RiverException:
                raise
            except Exception:
                # Log the error and propagate it to the caller.
                logging.exception("Error during upload communication")
                raise
        finally:
            self._streams.pop(stream_id, None)

    async def send_subscription(
        self,
        service_name: str,
        procedure_name: str,
        request: RequestType,
        request_serializer: Callable[[RequestType], Any],
        response_deserializer: Callable[[Any], ResponseType],
        error_deserializer: Callable[[Any], ErrorType],
    ) -> AsyncIterator[Union[ResponseType, ErrorType]]:
        """Sends a subscription request to the server.

        Expects the input and output be messages that will be msgpacked.
        After the opening ack, this yields deserialized responses, or
        deserialized errors for non-ok payloads, until the server sends
        CLOSE. Error payloads that fail to deserialize are logged and skipped.

        Raises:
            RiverException: if the first reply is not an ack.
        """
        stream_id = nanoid.generate()
        output: Channel[Any] = Channel(1024)
        self._streams[stream_id] = output
        try:
            msg = self.pack_transport_message(
                from_=self._from,
                to="SERVER",
                serviceName=service_name,
                procedureName=procedure_name,
                streamId=stream_id,
                controlFlags=STREAM_OPEN_BIT,
                payload=request_serializer(request),
            )
            await self.send_transport_message(msg)

            # Bound up front so the error log below never hits an unbound name.
            item: Any = None
            # Handle potential errors during communication
            try:
                ack_response = await output.get()
                if not ack_response.get("ack", None):
                    raise RiverException("ack error", "No ack received")

                async for item in output:
                    if "type" in item and item["type"] == "CLOSE":
                        break
                    if not item.get("ok", False):
                        try:
                            error = error_deserializer(item["payload"])
                        except Exception:
                            logging.exception(
                                f"Error during subscription error deserialization: {item}"
                            )
                            continue
                        yield error
                        continue
                    yield response_deserializer(item["payload"])
            except Exception:
                # Log the error and propagate it to the consumer.
                logging.exception(f"Error during subscription communication : {item}")
                raise
        finally:
            self._streams.pop(stream_id, None)

    async def send_stream(
        self,
        service_name: str,
        procedure_name: str,
        init: Optional[InitType],
        request: AsyncIterable[RequestType],
        init_serializer: Optional[Callable[[InitType], Any]],
        request_serializer: Callable[[RequestType], Any],
        response_deserializer: Callable[[Any], ResponseType],
        error_deserializer: Callable[[Any], ErrorType],
    ) -> AsyncIterator[Union[ResponseType, ErrorType]]:
        """Sends a bidirectional stream request to the server.

        Expects the input and output be messages that will be msgpacked.
        The stream is opened with ``init`` when both it and
        ``init_serializer`` are given. Otherwise it is opened with the first
        item of ``request``; if ``request`` is empty, nothing is sent and
        nothing is yielded. A background task sends the remaining items and
        then a CLOSE message. Meanwhile, this generator yields deserialized
        responses (or deserialized errors) until the server sends CLOSE.

        Raises:
            RiverException: if the first reply is not an ack.
        """
        stream_id = nanoid.generate()
        output: Channel[Any] = Channel(1024)
        self._streams[stream_id] = output
        try:
            # One iterator shared by the opening message and the encoder task,
            # so an item consumed to open the stream is not sent again.
            request_iter = aiter(request)
            if init and init_serializer:
                open_payload = init_serializer(init)
            else:
                try:
                    first = await anext(request_iter)
                except StopAsyncIteration:
                    return
                open_payload = request_serializer(first)
            await self.send_transport_message(
                self.pack_transport_message(
                    from_=self._from,
                    to="SERVER",
                    serviceName=service_name,
                    procedureName=procedure_name,
                    streamId=stream_id,
                    controlFlags=STREAM_OPEN_BIT,
                    payload=open_payload,
                )
            )

            async def _encode_stream() -> None:
                # Send the remaining request items, then close our side.
                async for item in request_iter:
                    await self.send_transport_message(
                        self.pack_transport_message(
                            from_=self._from,
                            to="SERVER",
                            serviceName=service_name,
                            procedureName=procedure_name,
                            streamId=stream_id,
                            controlFlags=0,
                            payload=request_serializer(item),
                        )
                    )
                await self.send_close_stream(service_name, procedure_name, stream_id)

            self._track_task(
                asyncio.create_task(_encode_stream()),
                "river.client.send_stream encoder",
            )

            # Only the opening message's ack is awaited here. Acks for the
            # messages sent by the encoder task arrive interleaved with
            # responses and are skipped in the loop below.
            ack_response = await output.get()
            if not ack_response.get("ack", None):
                raise RiverException("ack error", "No ack received")

            # Handle potential errors during communication
            try:
                async for item in output:
                    if "type" in item and item["type"] == "CLOSE":
                        # The stream is deregistered in the ``finally`` below.
                        break
                    if item.get("ack", None):
                        continue
                    if not item.get("ok", False):
                        try:
                            error = error_deserializer(item["payload"])
                        except Exception:
                            logging.exception(
                                f"Error during subscription error deserialization: {item}"
                            )
                            continue
                        yield error
                        continue
                    yield response_deserializer(item["payload"])
            except Exception:
                # Log the error and propagate it to the consumer.
                logging.exception("Error during stream communication")
                raise
        finally:
            self._streams.pop(stream_id, None)
File without changes
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # vim: tabstop=4 shiftwidth=4 softtabstop=4
4
+
5
+ from .run import main
6
+
7
+ if __name__ == "__main__":
8
+ main()