truss 0.11.0__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. (The original page linked to more details here.)

@@ -359,7 +359,7 @@ def push_chain(
359
359
  "--name",
360
360
  type=str,
361
361
  required=False,
362
- help="Name of the chain to be deployed, if not given, the entrypoint name is used.",
362
+ help="Name of the chain to be watched. If not given, the entrypoint name is used.",
363
363
  )
364
364
  @click.option(
365
365
  "--remote",
@@ -1,14 +1,15 @@
1
1
  import asyncio
2
2
  import logging
3
- from typing import Any, Callable, Dict
3
+ from typing import Any, Callable, Dict, Optional, Protocol
4
4
 
5
5
  import httpx
6
6
  from fastapi import APIRouter, WebSocket
7
7
  from fastapi.responses import JSONResponse, StreamingResponse
8
+ from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
8
9
  from httpx_ws import _exceptions as httpx_ws_exceptions
9
- from httpx_ws import aconnect_ws
10
10
  from starlette.requests import ClientDisconnect, Request
11
11
  from starlette.responses import Response
12
+ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
12
13
  from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
13
14
  from wsproto.events import BytesMessage, TextMessage
14
15
 
@@ -30,6 +31,10 @@ BASE_RETRY_EXCEPTIONS = (
30
31
  control_app = APIRouter()
31
32
 
32
33
 
34
+ class CloseableWebsocket(Protocol):
35
+ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
36
+
37
+
33
38
  @control_app.get("/")
34
39
  def index():
35
40
  return {}
@@ -118,13 +123,75 @@ def inference_retries(
118
123
  yield attempt
119
124
 
120
125
 
121
- async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
126
+ async def _safe_close_ws(
127
+ ws: CloseableWebsocket,
128
+ logger: logging.Logger,
129
+ code: int = 1000,
130
+ reason: Optional[str] = None,
131
+ ):
122
132
  try:
123
- await ws.close()
133
+ await ws.close(code, reason)
124
134
  except RuntimeError as close_error:
125
135
  logger.debug(f"Duplicate close of websocket: `{close_error}`.")
126
136
 
127
137
 
138
+ async def forward_to_server(
139
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
140
+ ) -> None:
141
+ while True:
142
+ message = await client_ws.receive()
143
+ if message.get("type") == "websocket.disconnect":
144
+ raise StartletteWebSocketDisconnect(
145
+ message.get("code", 1000), message.get("reason")
146
+ )
147
+ if "text" in message:
148
+ await server_ws.send_text(message["text"])
149
+ elif "bytes" in message:
150
+ await server_ws.send_bytes(message["bytes"])
151
+
152
+
153
+ async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
154
+ while True:
155
+ message = await server_ws.receive()
156
+ if message is None:
157
+ break
158
+ if isinstance(message, TextMessage):
159
+ await client_ws.send_text(message.data)
160
+ elif isinstance(message, BytesMessage):
161
+ await client_ws.send_bytes(message.data)
162
+
163
+
164
+ # NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
165
+ # versions of truss we're guaranteed to be running the control server with at least that version.
166
+ async def _handle_websocket_forwarding(
167
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
168
+ ):
169
+ logger = client_ws.app.state.logger
170
+ try:
171
+ async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined]
172
+ tg.create_task(forward_to_client(client_ws, server_ws))
173
+ tg.create_task(forward_to_server(client_ws, server_ws))
174
+ except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821
175
+ exc = eg.exceptions[0] # NB(nikhil): Only care about the first one.
176
+ if isinstance(exc, WebSocketDisconnect):
177
+ await _safe_close_ws(client_ws, logger, exc.code, exc.reason)
178
+ elif isinstance(exc, StartletteWebSocketDisconnect):
179
+ await _safe_close_ws(server_ws, logger, exc.code, exc.reason)
180
+ else:
181
+ logger.warning(f"Ungraceful websocket close: {exc}")
182
+ finally:
183
+ await _safe_close_ws(client_ws, logger)
184
+ await _safe_close_ws(server_ws, logger)
185
+
186
+
187
+ async def _attempt_websocket_proxy(
188
+ client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
189
+ ):
190
+ async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
191
+ await client_ws.accept()
192
+ await _handle_websocket_forwarding(client_ws, server_ws)
193
+
194
+
128
195
  async def proxy_ws(client_ws: WebSocket):
129
196
  proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
130
197
  logger = client_ws.app.state.logger
@@ -132,34 +199,7 @@ async def proxy_ws(client_ws: WebSocket):
132
199
  for attempt in inference_retries():
133
200
  with attempt:
134
201
  try:
135
- async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
136
- # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
137
- # for sending data, so it's not easy to create a unified wrapper.
138
- async def forward_to_server():
139
- while True:
140
- message = await client_ws.receive()
141
- if message.get("type") == "websocket.disconnect":
142
- break
143
- if "text" in message:
144
- await server_ws.send_text(message["text"])
145
- elif "bytes" in message:
146
- await server_ws.send_bytes(message["bytes"])
147
-
148
- async def forward_to_client():
149
- while True:
150
- message = await server_ws.receive()
151
- if message is None:
152
- break
153
- if isinstance(message, TextMessage):
154
- await client_ws.send_text(message.data)
155
- elif isinstance(message, BytesMessage):
156
- await client_ws.send_bytes(message.data)
157
-
158
- await client_ws.accept()
159
- try:
160
- await asyncio.gather(forward_to_client(), forward_to_server())
161
- finally:
162
- await _safe_close_ws(client_ws, logger)
202
+ await _attempt_websocket_proxy(client_ws, proxy_client, logger)
163
203
  except httpx_ws_exceptions.HTTPXWSException as e:
164
204
  logger.warning(f"WebSocket connection rejected: {e}")
165
205
  await _safe_close_ws(client_ws, logger)
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:
76
76
 
77
77
 
78
78
  async def _safe_close_websocket(
79
- ws: WebSocket, reason: Optional[str], status_code: int = 1000
79
+ ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
80
80
  ) -> None:
81
81
  try:
82
82
  await ws.close(code=status_code, reason=reason)
@@ -257,14 +257,14 @@ class BasetenEndpoints:
257
257
  try:
258
258
  await ws.accept()
259
259
  await self._model.websocket(ws)
260
- await _safe_close_websocket(ws, None, status_code=1000)
260
+ await _safe_close_websocket(ws, status_code=1000, reason=None)
261
261
  except WebSocketDisconnect as ws_error:
262
262
  logging.info(
263
263
  f"Client terminated websocket connection: `{ws_error}`."
264
264
  )
265
265
  except Exception:
266
266
  await _safe_close_websocket(
267
- ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
267
+ ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE
268
268
  )
269
269
  raise # Re raise to let `intercept_exceptions` deal with it.
270
270
 
@@ -1,4 +1,5 @@
1
- from unittest.mock import AsyncMock, MagicMock, patch
1
+ import asyncio
2
+ from unittest.mock import AsyncMock, MagicMock, call, patch
2
3
 
3
4
  import pytest
4
5
  from fastapi import FastAPI, WebSocket
@@ -31,33 +32,40 @@ def client_ws(app):
31
32
 
32
33
  @pytest.mark.asyncio
33
34
  async def test_proxy_ws_bidirectional_messaging(client_ws):
34
- """Test that both directions of communication work and clean up properly"""
35
- client_ws.receive.side_effect = [
36
- {"type": "websocket.receive", "text": "msg1"},
37
- {"type": "websocket.receive", "text": "msg2"},
38
- {"type": "websocket.disconnect"},
39
- ]
35
+ client_queue = asyncio.Queue()
36
+ client_ws.receive = client_queue.get
40
37
 
38
+ server_queue = asyncio.Queue()
41
39
  mock_server_ws = AsyncMock(spec=AsyncWebSocketSession)
42
- mock_server_ws.receive.side_effect = [
43
- TextMessage(data="response1"),
44
- TextMessage(data="response2"),
45
- None, # server closing connection
46
- ]
40
+ mock_server_ws.receive = server_queue.get
47
41
  mock_server_ws.__aenter__.return_value = mock_server_ws
48
42
  mock_server_ws.__aexit__.return_value = None
49
43
 
44
+ client_queue.put_nowait({"type": "websocket.receive", "text": "msg1"})
45
+ client_queue.put_nowait({"type": "websocket.receive", "text": "msg2"})
46
+ server_queue.put_nowait(TextMessage(data="response1"))
47
+ server_queue.put_nowait(TextMessage(data="response2"))
48
+
50
49
  with patch(
51
50
  "truss.templates.control.control.endpoints.aconnect_ws",
52
51
  return_value=mock_server_ws,
53
52
  ):
54
- await proxy_ws(client_ws)
53
+ proxy_task = asyncio.create_task(proxy_ws(client_ws))
54
+ await asyncio.sleep(0.5)
55
+
56
+ client_queue.put_nowait(
57
+ {"type": "websocket.disconnect", "code": 1002, "reason": "test-closure"}
58
+ )
59
+
60
+ await proxy_task
55
61
 
56
62
  assert mock_server_ws.send_text.call_count == 2
57
63
  assert mock_server_ws.send_text.call_args_list == [(("msg1",),), (("msg2",),)]
58
64
  assert client_ws.send_text.call_count == 2
59
65
  assert client_ws.send_text.call_args_list == [(("response1",),), (("response2",),)]
60
- client_ws.close.assert_called_once()
66
+
67
+ assert mock_server_ws.close.call_args_list[0] == call(1002, "test-closure")
68
+ client_ws.close.assert_called()
61
69
 
62
70
 
63
71
  @pytest.mark.asyncio
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.0
3
+ Version: 0.11.1rc1
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=CRz3AqGDAyv8YpcBWXUrnfjvNAauyo3yf8ZOGVsSt6g,32782
9
9
  truss/base/truss_config.py,sha256=7CtiJIwMHtDU8Wzn8UTJUVVunD0pWFl4QUVycK2aIpY,28055
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=y6pdIAGCcKOPG9bPuCXPfSA0onQm5x-tT_3blSBfPYg,16971
11
+ truss/cli/chains_commands.py,sha256=bqOXQ-0RPS66vSP_OPQdJ5dvctGiVrsGoSUMbURGdSI,16970
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
14
  truss/cli/train_commands.py,sha256=GDye7yXGL_nQvXAlY5MWsdj5x0zYOvcQw0Ubn14TiRU,14365
@@ -74,7 +74,7 @@ truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuG
74
74
  truss/templates/server.Dockerfile.jinja,sha256=CUYnF_hgxPGq2re7__0UPWlwzOHMoFkxp6NVKi3U16s,7071
75
75
  truss/templates/control/requirements.txt,sha256=Kk0tYID7trPk5gwX38Wrt2-YGWZAXFJCJRcqJ8ZzCjc,251
76
76
  truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
77
- truss/templates/control/control/endpoints.py,sha256=FM-sgao7I3gMoUTasM3Xq_g2LDoJQe75JxIoaQxzeNo,10031
77
+ truss/templates/control/control/endpoints.py,sha256=z8OJVBdAlJRl4mAdACVgYGGXaSx8Z8CR-_IVu70IidA,11259
78
78
  truss/templates/control/control/server.py,sha256=R4Y219i1dcz0kkksN8obLoX-YXWGo9iW1igindyG50c,3128
79
79
  truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5meqrOb3w_phMtKfaJI-GhwUfpiycDc8,413
80
80
  truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
@@ -97,7 +97,7 @@ truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
97
97
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
98
98
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
99
99
  truss/templates/server/requirements.txt,sha256=XblmpfxAmRo3X1V_9oMj8yjdpZ5Wk-C2oa3z6nq4OGw,672
100
- truss/templates/server/truss_server.py,sha256=ob_nceeGtFPZzKKdk_ZZGLoZrJOGE6hR52xM1sPR97A,19498
100
+ truss/templates/server/truss_server.py,sha256=m1DdCBqT80p8Sfft60vpSeH1ZNoUVpTdYPIcqkIE8CA,19519
101
101
  truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
102
102
  truss/templates/server/common/errors.py,sha256=qWeZlmNI8ZGbZbOIp_mtS6IKvUFIzhj3QH8zp-xTp9o,8554
103
103
  truss/templates/server/common/patches.py,sha256=uEOzvDnXsHOkTSa8zygGYuR4GHhrFNVHNQc5peJcwvo,1393
@@ -161,7 +161,7 @@ truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZ
161
161
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
162
162
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
163
163
  truss/tests/remote/baseten/test_service.py,sha256=ufZbtQlBNIzFCxRt_iE-APLpWbVw_3ViUpSh6H9W5nU,1945
164
- truss/tests/templates/control/control/test_endpoints.py,sha256=tGU3w8zOKC8LfWGdhp-TlV7E603KXg2xGwpqDdf8Pnw,3385
164
+ truss/tests/templates/control/control/test_endpoints.py,sha256=wUC24DvQPgYbYmIxUIFvRtNjlATAoO-1r9XY38iidLI,3678
165
165
  truss/tests/templates/control/control/test_server.py,sha256=r1O3VEK9eoIL2-cg8nYLXYct_H3jf5rGp1wLT1KBdeA,9488
166
166
  truss/tests/templates/control/control/test_server_integration.py,sha256=EdDY3nLzjrRCJ5LI5yZsNCEImSRkxTL7Rn9mGnK67zA,11837
167
167
  truss/tests/templates/control/control/helpers/test_context_managers.py,sha256=3LoonRaKu_UvhaWs1eNmEQCZq-iJ3aIjI0Mn4amC8Bw,283
@@ -345,27 +345,27 @@ truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
345
345
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
346
346
  truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
347
347
  truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
348
- truss_chains/public_types.py,sha256=q8Oet6MpECW1FhWW25SCExpZhmk4cFmEsqrO30oZIMw,29112
348
+ truss_chains/public_types.py,sha256=8cBstm5m_9aN_U2WoSHQ9s2z6gTdSEl-nVgXyneU3nY,29166
349
349
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
350
350
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
351
351
  truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,12521
352
352
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
353
353
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
354
354
  truss_chains/deployment/code_gen.py,sha256=AmAUZ3h1hP3uYkl3J6o096K5RFLuBOP7kOFSnFC_C4U,32568
355
- truss_chains/deployment/deployment_client.py,sha256=7ckyh71wYVXBQtKfy4gh_MddMQWTNVpOL_FsqnaKFPo,32811
355
+ truss_chains/deployment/deployment_client.py,sha256=haFiVmQek42ewlN_YflBaRDQT4ZYbmT20tvvJOkcUX0,32899
356
356
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
357
357
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
358
358
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
359
359
  truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e68i2Z-fFlyV3sz0qH8,2376
360
360
  truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
361
- truss_chains/remote_chainlet/utils.py,sha256=O_5P-VAUvg0cegEW1uKCOf5EBwD8rEGYVoGMivOmc7k,22374
361
+ truss_chains/remote_chainlet/utils.py,sha256=DR04xFbAoGGCTiKG0ROqkzoWWwTp6aiiH3VduWxJelE,22644
362
362
  truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
363
363
  truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
364
364
  truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
365
365
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
366
366
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
367
- truss-0.11.0.dist-info/METADATA,sha256=xoZ_Knb3pqkfrWKv6bNX21VdkYMWHZ3AWTig-w395Es,6669
368
- truss-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
- truss-0.11.0.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
370
- truss-0.11.0.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
371
- truss-0.11.0.dist-info/RECORD,,
367
+ truss-0.11.1rc1.dist-info/METADATA,sha256=Wwt0gl5KmgZp_yyELNdWsVKlR_KO7HOPEKwFzHZeDFM,6672
368
+ truss-0.11.1rc1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
+ truss-0.11.1rc1.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
370
+ truss-0.11.1rc1.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
371
+ truss-0.11.1rc1.dist-info/RECORD,,
@@ -655,7 +655,9 @@ class _Watcher:
655
655
  with framework.ChainletImporter.import_target(
656
656
  source, entrypoint
657
657
  ) as entrypoint_cls:
658
- self._deployed_chain_name = name or entrypoint_cls.__name__
658
+ self._deployed_chain_name = (
659
+ name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.__name__
660
+ )
659
661
  self._chain_root = _get_chain_root(entrypoint_cls)
660
662
  chainlet_names = set(
661
663
  desc.display_name
@@ -677,7 +679,7 @@ class _Watcher:
677
679
  )
678
680
  if not chain_id:
679
681
  raise public_types.ChainsDeploymentError(
680
- f"Chain `{chain_id}` was not found."
682
+ f"Chain `{self._deployed_chain_name}` was not found."
681
683
  )
682
684
  self._status_page_url = b10_service.URLConfig.status_page_url(
683
685
  self._remote_provider.remote_url, b10_service.URLConfig.CHAIN, chain_id
@@ -473,6 +473,7 @@ class WebSocketProtocol(Protocol):
473
473
 
474
474
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
475
475
 
476
+ async def receive(self) -> Union[str, bytes]: ...
476
477
  async def receive_text(self) -> str: ...
477
478
  async def receive_bytes(self) -> bytes: ...
478
479
  async def receive_json(self) -> Any: ...
@@ -11,6 +11,7 @@ import textwrap
11
11
  import threading
12
12
  import time
13
13
  import traceback
14
+ import typing
14
15
  from collections.abc import AsyncIterator
15
16
  from typing import (
16
17
  TYPE_CHECKING,
@@ -586,6 +587,13 @@ class WebsocketWrapperFastAPI:
586
587
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
587
588
  await self._websocket.close(code=code, reason=reason)
588
589
 
590
+ async def receive(self) -> Union[str, bytes]:
591
+ message = await self._websocket.receive()
592
+ if message.get("text"):
593
+ return typing.cast(str, message["text"])
594
+ else:
595
+ return typing.cast(bytes, message["bytes"])
596
+
589
597
  async def receive_text(self) -> str:
590
598
  return await self._websocket.receive_text()
591
599