truss 0.11.1rc2__py3-none-any.whl → 0.11.1rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

@@ -359,7 +359,7 @@ def push_chain(
359
359
  "--name",
360
360
  type=str,
361
361
  required=False,
362
- help="Name of the chain to be deployed, if not given, the entrypoint name is used.",
362
+ help="Name of the chain to be watched. If not given, the entrypoint name is used.",
363
363
  )
364
364
  @click.option(
365
365
  "--remote",
@@ -1,14 +1,15 @@
1
1
  import asyncio
2
2
  import logging
3
- from typing import Any, Callable, Dict
3
+ from typing import Any, Callable, Dict, Optional, Protocol
4
4
 
5
5
  import httpx
6
6
  from fastapi import APIRouter, WebSocket
7
7
  from fastapi.responses import JSONResponse, StreamingResponse
8
+ from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
8
9
  from httpx_ws import _exceptions as httpx_ws_exceptions
9
- from httpx_ws import aconnect_ws
10
10
  from starlette.requests import ClientDisconnect, Request
11
11
  from starlette.responses import Response
12
+ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
12
13
  from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
13
14
  from wsproto.events import BytesMessage, TextMessage
14
15
 
@@ -29,6 +30,15 @@ BASE_RETRY_EXCEPTIONS = (
29
30
 
30
31
  control_app = APIRouter()
31
32
 
33
+ WEBSOCKET_NORMAL_CLOSURE_CODE = 1000
34
+ WEBSOCKET_SERVER_ERROR_CODE = 1011
35
+
36
+
37
+ class CloseableWebsocket(Protocol):
38
+ async def close(
39
+ self, code: int = WEBSOCKET_NORMAL_CLOSURE_CODE, reason: Optional[str] = None
40
+ ) -> None: ...
41
+
32
42
 
33
43
  @control_app.get("/")
34
44
  def index():
@@ -118,13 +128,79 @@ def inference_retries(
118
128
  yield attempt
119
129
 
120
130
 
121
- async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
131
+ async def _safe_close_ws(
132
+ ws: CloseableWebsocket,
133
+ logger: logging.Logger,
134
+ code: int,
135
+ reason: Optional[str] = None,
136
+ ):
122
137
  try:
123
- await ws.close()
138
+ await ws.close(code, reason)
124
139
  except RuntimeError as close_error:
125
140
  logger.debug(f"Duplicate close of websocket: `{close_error}`.")
126
141
 
127
142
 
143
+ async def forward_to_server(
144
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
145
+ ) -> None:
146
+ while True:
147
+ message = await client_ws.receive()
148
+ if message.get("type") == "websocket.disconnect":
149
+ raise StartletteWebSocketDisconnect(
150
+ message.get("code", 1000), message.get("reason")
151
+ )
152
+ if "text" in message:
153
+ await server_ws.send_text(message["text"])
154
+ elif "bytes" in message:
155
+ await server_ws.send_bytes(message["bytes"])
156
+
157
+
158
+ async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
159
+ while True:
160
+ message = await server_ws.receive()
161
+ if isinstance(message, TextMessage):
162
+ await client_ws.send_text(message.data)
163
+ elif isinstance(message, BytesMessage):
164
+ await client_ws.send_bytes(message.data)
165
+
166
+
167
+ # NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
168
+ # versions of truss we're guaranteed to be running the control server with at least that version.
169
+ async def _handle_websocket_forwarding(
170
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
171
+ ):
172
+ logger = client_ws.app.state.logger
173
+ try:
174
+ async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined]
175
+ tg.create_task(forward_to_client(client_ws, server_ws))
176
+ tg.create_task(forward_to_server(client_ws, server_ws))
177
+ except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821
178
+ # NB(nikhil): The first websocket proxy method to raise an error will
179
+ # be surfaced here, and that contains the information we want to forward to the
180
+ # other websocket. Further errors might raise as a result of cancellation, but we
181
+ # can safely ignore those.
182
+ exc = eg.exceptions[0]
183
+ if isinstance(exc, WebSocketDisconnect):
184
+ await _safe_close_ws(client_ws, logger, exc.code, exc.reason)
185
+ elif isinstance(exc, StartletteWebSocketDisconnect):
186
+ await _safe_close_ws(server_ws, logger, exc.code, exc.reason)
187
+ else:
188
+ logger.warning(f"Ungraceful websocket close: {exc}")
189
+ finally:
190
+ # NB(nikhil): In most common cases, both websockets would have been successfully
191
+ # closed with applicable codes above, these lines are just a failsafe.
192
+ await _safe_close_ws(client_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE)
193
+ await _safe_close_ws(server_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE)
194
+
195
+
196
+ async def _attempt_websocket_proxy(
197
+ client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
198
+ ):
199
+ async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
200
+ await client_ws.accept()
201
+ await _handle_websocket_forwarding(client_ws, server_ws)
202
+
203
+
128
204
  async def proxy_ws(client_ws: WebSocket):
129
205
  proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
130
206
  logger = client_ws.app.state.logger
@@ -132,37 +208,10 @@ async def proxy_ws(client_ws: WebSocket):
132
208
  for attempt in inference_retries():
133
209
  with attempt:
134
210
  try:
135
- async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
136
- # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
137
- # for sending data, so it's not easy to create a unified wrapper.
138
- async def forward_to_server():
139
- while True:
140
- message = await client_ws.receive()
141
- if message.get("type") == "websocket.disconnect":
142
- break
143
- if "text" in message:
144
- await server_ws.send_text(message["text"])
145
- elif "bytes" in message:
146
- await server_ws.send_bytes(message["bytes"])
147
-
148
- async def forward_to_client():
149
- while True:
150
- message = await server_ws.receive()
151
- if message is None:
152
- break
153
- if isinstance(message, TextMessage):
154
- await client_ws.send_text(message.data)
155
- elif isinstance(message, BytesMessage):
156
- await client_ws.send_bytes(message.data)
157
-
158
- await client_ws.accept()
159
- try:
160
- await asyncio.gather(forward_to_client(), forward_to_server())
161
- finally:
162
- await _safe_close_ws(client_ws, logger)
211
+ await _attempt_websocket_proxy(client_ws, proxy_client, logger)
163
212
  except httpx_ws_exceptions.HTTPXWSException as e:
164
213
  logger.warning(f"WebSocket connection rejected: {e}")
165
- await _safe_close_ws(client_ws, logger)
214
+ await _safe_close_ws(client_ws, logger, WEBSOCKET_SERVER_ERROR_CODE)
166
215
  break
167
216
 
168
217
 
@@ -6,7 +6,7 @@ loguru>=0.7.2
6
6
  python-json-logger>=2.0.2
7
7
  tenacity>=8.1.0
8
8
  # To avoid divergence, this should follow the latest release.
9
- truss==0.9.100
9
+ truss==0.11.1rc3
10
10
  uvicorn>=0.24.0
11
11
  uvloop>=0.19.0
12
12
  websockets>=10.0
@@ -18,6 +18,7 @@ _BASETEN_DOWNSTREAM_ERROR_CODE = 600
18
18
  _BASETEN_CLIENT_ERROR_CODE = 700
19
19
 
20
20
  MODEL_ERROR_MESSAGE = "Internal Server Error (in model/chainlet)."
21
+ WEBSOCKET_SERVER_ERROR_CODE = 1011
21
22
 
22
23
 
23
24
  class ModelMissingError(Exception):
@@ -6,6 +6,7 @@ import pathlib
6
6
  import time
7
7
  from typing import Iterator, List, Optional, Sequence
8
8
 
9
+ import opentelemetry.exporter.otlp.proto.http.trace_exporter as oltp_exporter
9
10
  import opentelemetry.sdk.resources as resources
10
11
  import opentelemetry.sdk.trace as sdk_trace
11
12
  import opentelemetry.sdk.trace.export as trace_export
@@ -15,6 +16,7 @@ from shared import secrets_resolver
15
16
  logger = logging.getLogger(__name__)
16
17
 
17
18
  ATTR_NAME_DURATION = "duration_sec"
19
+ OTEL_EXPORTER_OTLP_ENDPOINT = "OTEL_EXPORTER_OTLP_ENDPOINT"
18
20
  # Writing trace data to a file is only intended for testing / debugging.
19
21
  OTEL_TRACING_NDJSON_FILE = "OTEL_TRACING_NDJSON_FILE"
20
22
  # Exporting trace data to a public honeycomb instance (not our cluster collector)
@@ -65,6 +67,13 @@ def get_truss_tracer(secrets: secrets_resolver.Secrets, config) -> trace.Tracer:
65
67
  return _truss_tracer
66
68
 
67
69
  span_processors: List[sdk_trace.SpanProcessor] = []
70
+ if otlp_endpoint := os.getenv(OTEL_EXPORTER_OTLP_ENDPOINT):
71
+ if enable_tracing_data:
72
+ logger.info(f"Exporting trace data to {OTEL_EXPORTER_OTLP_ENDPOINT}.")
73
+ otlp_exporter = oltp_exporter.OTLPSpanExporter(endpoint=otlp_endpoint)
74
+ otlp_processor = sdk_trace.export.BatchSpanProcessor(otlp_exporter)
75
+ span_processors.append(otlp_processor)
76
+
68
77
  if tracing_log_file := os.getenv(OTEL_TRACING_NDJSON_FILE):
69
78
  if enable_tracing_data:
70
79
  logger.info(f"Exporting trace data to file `{tracing_log_file}`.")
@@ -72,6 +81,21 @@ def get_truss_tracer(secrets: secrets_resolver.Secrets, config) -> trace.Tracer:
72
81
  file_processor = sdk_trace.export.SimpleSpanProcessor(json_file_exporter)
73
82
  span_processors.append(file_processor)
74
83
 
84
+ if (
85
+ honeycomb_dataset := os.getenv(HONEYCOMB_DATASET)
86
+ ) and HONEYCOMB_API_KEY in secrets:
87
+ honeycomb_api_key = secrets[HONEYCOMB_API_KEY]
88
+ logger.info("Exporting trace data to honeycomb.")
89
+ honeycomb_exporter = oltp_exporter.OTLPSpanExporter(
90
+ endpoint="https://api.honeycomb.io/v1/traces",
91
+ headers={
92
+ "x-honeycomb-team": honeycomb_api_key,
93
+ "x-honeycomb-dataset": honeycomb_dataset,
94
+ },
95
+ )
96
+ honeycomb_processor = sdk_trace.export.BatchSpanProcessor(honeycomb_exporter)
97
+ span_processors.append(honeycomb_processor)
98
+
75
99
  if span_processors and enable_tracing_data:
76
100
  logger.info("Instantiating truss tracer.")
77
101
  resource = resources.Resource.create({resources.SERVICE_NAME: "truss-server"})
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:
76
76
 
77
77
 
78
78
  async def _safe_close_websocket(
79
- ws: WebSocket, reason: Optional[str], status_code: int = 1000
79
+ ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
80
80
  ) -> None:
81
81
  try:
82
82
  await ws.close(code=status_code, reason=reason)
@@ -257,14 +257,16 @@ class BasetenEndpoints:
257
257
  try:
258
258
  await ws.accept()
259
259
  await self._model.websocket(ws)
260
- await _safe_close_websocket(ws, None, status_code=1000)
260
+ await _safe_close_websocket(ws, status_code=1000, reason=None)
261
261
  except WebSocketDisconnect as ws_error:
262
262
  logging.info(
263
263
  f"Client terminated websocket connection: `{ws_error}`."
264
264
  )
265
265
  except Exception:
266
266
  await _safe_close_websocket(
267
- ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
267
+ ws,
268
+ status_code=errors.WEBSOCKET_SERVER_ERROR_CODE,
269
+ reason=errors.MODEL_ERROR_MESSAGE,
268
270
  )
269
271
  raise # Re raise to let `intercept_exceptions` deal with it.
270
272
 
@@ -1,4 +1,5 @@
1
- from unittest.mock import AsyncMock, MagicMock, patch
1
+ import asyncio
2
+ from unittest.mock import AsyncMock, MagicMock, call, patch
2
3
 
3
4
  import pytest
4
5
  from fastapi import FastAPI, WebSocket
@@ -31,33 +32,38 @@ def client_ws(app):
31
32
 
32
33
  @pytest.mark.asyncio
33
34
  async def test_proxy_ws_bidirectional_messaging(client_ws):
34
- """Test that both directions of communication work and clean up properly"""
35
- client_ws.receive.side_effect = [
36
- {"type": "websocket.receive", "text": "msg1"},
37
- {"type": "websocket.receive", "text": "msg2"},
38
- {"type": "websocket.disconnect"},
39
- ]
35
+ client_queue = asyncio.Queue()
36
+ client_ws.receive = client_queue.get
40
37
 
38
+ server_queue = asyncio.Queue()
41
39
  mock_server_ws = AsyncMock(spec=AsyncWebSocketSession)
42
- mock_server_ws.receive.side_effect = [
43
- TextMessage(data="response1"),
44
- TextMessage(data="response2"),
45
- None, # server closing connection
46
- ]
40
+ mock_server_ws.receive = server_queue.get
47
41
  mock_server_ws.__aenter__.return_value = mock_server_ws
48
42
  mock_server_ws.__aexit__.return_value = None
49
43
 
44
+ client_queue.put_nowait({"type": "websocket.receive", "text": "msg1"})
45
+ client_queue.put_nowait({"type": "websocket.receive", "text": "msg2"})
46
+ server_queue.put_nowait(TextMessage(data="response1"))
47
+ server_queue.put_nowait(TextMessage(data="response2"))
48
+
50
49
  with patch(
51
50
  "truss.templates.control.control.endpoints.aconnect_ws",
52
51
  return_value=mock_server_ws,
53
52
  ):
54
- await proxy_ws(client_ws)
53
+ proxy_task = asyncio.create_task(proxy_ws(client_ws))
54
+ client_queue.put_nowait(
55
+ {"type": "websocket.disconnect", "code": 1002, "reason": "test-closure"}
56
+ )
57
+
58
+ await proxy_task
55
59
 
56
60
  assert mock_server_ws.send_text.call_count == 2
57
61
  assert mock_server_ws.send_text.call_args_list == [(("msg1",),), (("msg2",),)]
58
62
  assert client_ws.send_text.call_count == 2
59
63
  assert client_ws.send_text.call_args_list == [(("response1",),), (("response2",),)]
60
- client_ws.close.assert_called_once()
64
+
65
+ assert mock_server_ws.close.call_args_list[0] == call(1002, "test-closure")
66
+ client_ws.close.assert_called()
61
67
 
62
68
 
63
69
  @pytest.mark.asyncio
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.1rc2
3
+ Version: 0.11.1rc4
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=CRz3AqGDAyv8YpcBWXUrnfjvNAauyo3yf8ZOGVsSt6g,32782
9
9
  truss/base/truss_config.py,sha256=7CtiJIwMHtDU8Wzn8UTJUVVunD0pWFl4QUVycK2aIpY,28055
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=y6pdIAGCcKOPG9bPuCXPfSA0onQm5x-tT_3blSBfPYg,16971
11
+ truss/cli/chains_commands.py,sha256=bqOXQ-0RPS66vSP_OPQdJ5dvctGiVrsGoSUMbURGdSI,16970
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
14
  truss/cli/train_commands.py,sha256=GDye7yXGL_nQvXAlY5MWsdj5x0zYOvcQw0Ubn14TiRU,14365
@@ -72,9 +72,9 @@ truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj
72
72
  truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
73
73
  truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
74
74
  truss/templates/server.Dockerfile.jinja,sha256=CUYnF_hgxPGq2re7__0UPWlwzOHMoFkxp6NVKi3U16s,7071
75
- truss/templates/control/requirements.txt,sha256=Kk0tYID7trPk5gwX38Wrt2-YGWZAXFJCJRcqJ8ZzCjc,251
75
+ truss/templates/control/requirements.txt,sha256=D2kIrXfCKlWl8LO7quTUlCFYuT3Dn_MVAlCG_0YjHQY,253
76
76
  truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
77
- truss/templates/control/control/endpoints.py,sha256=FM-sgao7I3gMoUTasM3Xq_g2LDoJQe75JxIoaQxzeNo,10031
77
+ truss/templates/control/control/endpoints.py,sha256=VQ1lvZjFvR091yRkiFdvXw1Q7PiNGXT9rJwY7_sX6yg,11828
78
78
  truss/templates/control/control/server.py,sha256=R4Y219i1dcz0kkksN8obLoX-YXWGo9iW1igindyG50c,3128
79
79
  truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5meqrOb3w_phMtKfaJI-GhwUfpiycDc8,413
80
80
  truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
@@ -97,13 +97,13 @@ truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
97
97
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
98
98
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
99
99
  truss/templates/server/requirements.txt,sha256=XblmpfxAmRo3X1V_9oMj8yjdpZ5Wk-C2oa3z6nq4OGw,672
100
- truss/templates/server/truss_server.py,sha256=ob_nceeGtFPZzKKdk_ZZGLoZrJOGE6hR52xM1sPR97A,19498
100
+ truss/templates/server/truss_server.py,sha256=noXfGJMsKIhgF4oI_8LC1UHkcx8Vg8nGSITZJ_bkRFQ,19598
101
101
  truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
102
- truss/templates/server/common/errors.py,sha256=qWeZlmNI8ZGbZbOIp_mtS6IKvUFIzhj3QH8zp-xTp9o,8554
102
+ truss/templates/server/common/errors.py,sha256=My0P6-Y7imVTICIhazHT0vlSu3XJDH7As06OyVzu4Do,8589
103
103
  truss/templates/server/common/patches.py,sha256=uEOzvDnXsHOkTSa8zygGYuR4GHhrFNVHNQc5peJcwvo,1393
104
104
  truss/templates/server/common/retry.py,sha256=dtz6yvwLoY0i55FnxECz57zEOKjAhGMYvvM-k9jiR9c,624
105
105
  truss/templates/server/common/schema.py,sha256=WLFtVyEKmk4whg5_gk6Gt1vOD6wM5fWKLb4zNuD0bkw,6042
106
- truss/templates/server/common/tracing.py,sha256=TDokphTO0O-b0xZLkkDMU6Z_JIsaZA0aimL6UIQB5eI,4808
106
+ truss/templates/server/common/tracing.py,sha256=XSTXNoRtV8vXwveJoX3H32go0JKnLmznZ2TtrVzIe4M,5967
107
107
  truss/templates/server/common/patches/whisper/patch.py,sha256=kDECQ-wmEpeAZFhUTQP457ofueeMsm7DgNy9tqinhJQ,2383
108
108
  truss/templates/shared/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
109
109
  truss/templates/shared/dynamic_config_resolver.py,sha256=75s42NFhQI5jL7BqlJH_UkuQS7ptbtFh13f2nh6X5Wo,920
@@ -161,7 +161,7 @@ truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZ
161
161
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
162
162
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
163
163
  truss/tests/remote/baseten/test_service.py,sha256=ufZbtQlBNIzFCxRt_iE-APLpWbVw_3ViUpSh6H9W5nU,1945
164
- truss/tests/templates/control/control/test_endpoints.py,sha256=tGU3w8zOKC8LfWGdhp-TlV7E603KXg2xGwpqDdf8Pnw,3385
164
+ truss/tests/templates/control/control/test_endpoints.py,sha256=fxTiiCR0ltaHCL_-v-22Ie1qVgnch1lqcj3w0U3R-fk,3644
165
165
  truss/tests/templates/control/control/test_server.py,sha256=r1O3VEK9eoIL2-cg8nYLXYct_H3jf5rGp1wLT1KBdeA,9488
166
166
  truss/tests/templates/control/control/test_server_integration.py,sha256=EdDY3nLzjrRCJ5LI5yZsNCEImSRkxTL7Rn9mGnK67zA,11837
167
167
  truss/tests/templates/control/control/helpers/test_context_managers.py,sha256=3LoonRaKu_UvhaWs1eNmEQCZq-iJ3aIjI0Mn4amC8Bw,283
@@ -345,27 +345,27 @@ truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
345
345
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
346
346
  truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
347
347
  truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
348
- truss_chains/public_types.py,sha256=q8Oet6MpECW1FhWW25SCExpZhmk4cFmEsqrO30oZIMw,29112
348
+ truss_chains/public_types.py,sha256=8cBstm5m_9aN_U2WoSHQ9s2z6gTdSEl-nVgXyneU3nY,29166
349
349
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
350
350
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
351
351
  truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,12521
352
352
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
353
353
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
354
354
  truss_chains/deployment/code_gen.py,sha256=AmAUZ3h1hP3uYkl3J6o096K5RFLuBOP7kOFSnFC_C4U,32568
355
- truss_chains/deployment/deployment_client.py,sha256=7ckyh71wYVXBQtKfy4gh_MddMQWTNVpOL_FsqnaKFPo,32811
355
+ truss_chains/deployment/deployment_client.py,sha256=haFiVmQek42ewlN_YflBaRDQT4ZYbmT20tvvJOkcUX0,32899
356
356
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
357
357
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
358
358
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
359
359
  truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e68i2Z-fFlyV3sz0qH8,2376
360
360
  truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
361
- truss_chains/remote_chainlet/utils.py,sha256=O_5P-VAUvg0cegEW1uKCOf5EBwD8rEGYVoGMivOmc7k,22374
361
+ truss_chains/remote_chainlet/utils.py,sha256=xX1t3e-BsYkWrxQIqfKRl4PHGuVyW3oleWFQpXSAynI,22949
362
362
  truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
363
363
  truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
364
364
  truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
365
365
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
366
366
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
367
- truss-0.11.1rc2.dist-info/METADATA,sha256=-QNAojZwEkUwM3B6Jo9KIbVMpBstsTCKx-qR1S_MFJM,6672
368
- truss-0.11.1rc2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
- truss-0.11.1rc2.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
370
- truss-0.11.1rc2.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
371
- truss-0.11.1rc2.dist-info/RECORD,,
367
+ truss-0.11.1rc4.dist-info/METADATA,sha256=PYD_kydnF-Z7GjTBOB0-JA0lQjQMtiBn7Y-30qyT7wY,6672
368
+ truss-0.11.1rc4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
+ truss-0.11.1rc4.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
370
+ truss-0.11.1rc4.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
371
+ truss-0.11.1rc4.dist-info/RECORD,,
@@ -655,7 +655,9 @@ class _Watcher:
655
655
  with framework.ChainletImporter.import_target(
656
656
  source, entrypoint
657
657
  ) as entrypoint_cls:
658
- self._deployed_chain_name = name or entrypoint_cls.__name__
658
+ self._deployed_chain_name = (
659
+ name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.__name__
660
+ )
659
661
  self._chain_root = _get_chain_root(entrypoint_cls)
660
662
  chainlet_names = set(
661
663
  desc.display_name
@@ -677,7 +679,7 @@ class _Watcher:
677
679
  )
678
680
  if not chain_id:
679
681
  raise public_types.ChainsDeploymentError(
680
- f"Chain `{chain_id}` was not found."
682
+ f"Chain `{self._deployed_chain_name}` was not found."
681
683
  )
682
684
  self._status_page_url = b10_service.URLConfig.status_page_url(
683
685
  self._remote_provider.remote_url, b10_service.URLConfig.CHAIN, chain_id
@@ -473,6 +473,7 @@ class WebSocketProtocol(Protocol):
473
473
 
474
474
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
475
475
 
476
+ async def receive(self) -> Union[str, bytes]: ...
476
477
  async def receive_text(self) -> str: ...
477
478
  async def receive_bytes(self) -> bytes: ...
478
479
  async def receive_json(self) -> Any: ...
@@ -11,6 +11,7 @@ import textwrap
11
11
  import threading
12
12
  import time
13
13
  import traceback
14
+ import typing
14
15
  from collections.abc import AsyncIterator
15
16
  from typing import (
16
17
  TYPE_CHECKING,
@@ -586,6 +587,18 @@ class WebsocketWrapperFastAPI:
586
587
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
587
588
  await self._websocket.close(code=code, reason=reason)
588
589
 
590
+ async def receive(self) -> Union[str, bytes]:
591
+ message = await self._websocket.receive()
592
+
593
+ if message.get("type") == "websocket.disconnect":
594
+ # NB(nikhil): Mimics FastAPI `_raise_on_disconnect`, since otherwise the user has no
595
+ # way of detecting that the client disconnected.
596
+ raise fastapi.WebSocketDisconnect(message["code"], message.get("reason"))
597
+ elif message.get("text"):
598
+ return typing.cast(str, message["text"])
599
+ else:
600
+ return typing.cast(bytes, message["bytes"])
601
+
589
602
  async def receive_text(self) -> str:
590
603
  return await self._websocket.receive_text()
591
604