streamrelay 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- streamrelay/__init__.py +45 -0
- streamrelay/consumer.py +249 -0
- streamrelay/crypto.py +170 -0
- streamrelay/executor.py +146 -0
- streamrelay/producer.py +267 -0
- streamrelay/server.py +533 -0
- streamrelay-0.1.0.dist-info/METADATA +788 -0
- streamrelay-0.1.0.dist-info/RECORD +11 -0
- streamrelay-0.1.0.dist-info/WHEEL +4 -0
- streamrelay-0.1.0.dist-info/entry_points.txt +2 -0
- streamrelay-0.1.0.dist-info/licenses/LICENSE +153 -0
streamrelay/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
streamrelay — Real-time token streaming from batch HPC executors via WebSocket relay.
|
|
3
|
+
|
|
4
|
+
Solves a fundamental gap: HPC job schedulers (Globus Compute, SLURM, PBS) execute
|
|
5
|
+
functions to completion and return a single result. This library adds a lightweight
|
|
6
|
+
bidirectional channel so tokens stream out of the compute node in real time, with
|
|
7
|
+
both ends connecting *outbound* to the relay (no inbound ports needed, no VPN).
|
|
8
|
+
|
|
9
|
+
Basic usage:
|
|
10
|
+
|
|
11
|
+
# On the HPC compute node (producer)
|
|
12
|
+
from streamrelay import RelayProducer
|
|
13
|
+
with RelayProducer(relay_url, channel_id) as relay:
|
|
14
|
+
for token in your_model_stream(prompt):
|
|
15
|
+
relay.send_token(token)
|
|
16
|
+
# "done" signal sent automatically on exit
|
|
17
|
+
|
|
18
|
+
# On your client/middleware (consumer)
|
|
19
|
+
from streamrelay import RelayConsumer
|
|
20
|
+
for token in RelayConsumer(relay_url, channel_id).stream():
|
|
21
|
+
print(token, end="", flush=True)
|
|
22
|
+
|
|
23
|
+
# High-level: submit a Globus Compute function and stream its output
|
|
24
|
+
from streamrelay import StreamingExecutor
|
|
25
|
+
async with StreamingExecutor(endpoint_id, relay_url) as executor:
|
|
26
|
+
async for token in executor.stream(my_fn, prompt="Hello"):
|
|
27
|
+
print(token, end="", flush=True)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from streamrelay.consumer import RelayConsumer
|
|
31
|
+
from streamrelay.producer import RelayProducer
|
|
32
|
+
from streamrelay.server import start_relay
|
|
33
|
+
from streamrelay.crypto import encrypt_message, decrypt_message, generate_key
|
|
34
|
+
from streamrelay.executor import StreamingExecutor
|
|
35
|
+
|
|
36
|
+
# Release version of the streamrelay package.
__version__ = "0.1.0"

# Names exported by ``from streamrelay import *`` (the public API surface).
__all__ = [
    "RelayProducer", "RelayConsumer", "StreamingExecutor", "start_relay",
    "generate_key", "encrypt_message", "decrypt_message",
]
|
streamrelay/consumer.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""
|
|
2
|
+
streamrelay.consumer — Receive tokens from the relay on the client side.
|
|
3
|
+
|
|
4
|
+
WHAT THIS FILE DOES
|
|
5
|
+
===================
|
|
6
|
+
This module runs on YOUR MACHINE (or your web server) — the application side
|
|
7
|
+
that wants to display or process tokens as they arrive from the HPC compute node.
|
|
8
|
+
|
|
9
|
+
RelayConsumer connects to the relay as a consumer and yields token strings one
|
|
10
|
+
by one as the producer sends them. From the caller's perspective, it looks just
|
|
11
|
+
like iterating over any Python generator:
|
|
12
|
+
|
|
13
|
+
for token in consumer.stream():
|
|
14
|
+
print(token, end="", flush=True)
|
|
15
|
+
|
|
16
|
+
All the WebSocket complexity — connecting, receiving, decrypting, parsing JSON,
|
|
17
|
+
detecting the end of the stream — is handled internally.
|
|
18
|
+
|
|
19
|
+
TWO ITERATION STYLES
|
|
20
|
+
====================
|
|
21
|
+
# Synchronous — use in plain scripts, notebooks, CLI tools
|
|
22
|
+
consumer = RelayConsumer(relay_url, channel_id)
|
|
23
|
+
for token in consumer.stream():
|
|
24
|
+
print(token, end="", flush=True)
|
|
25
|
+
|
|
26
|
+
# Asynchronous — use in FastAPI, aiohttp, or any asyncio application
|
|
27
|
+
async for token in RelayConsumer(relay_url, channel_id):
|
|
28
|
+
yield f"data: {token}\\n\\n" # forward as Server-Sent Events to browser
|
|
29
|
+
|
|
30
|
+
The async version also supports ``await consumer.acollect()`` to get the full
|
|
31
|
+
response as a single string.
|
|
32
|
+
|
|
33
|
+
TIMING
|
|
34
|
+
======
|
|
35
|
+
The order does not matter: connect the consumer before or after submitting the HPC job. The relay
|
|
36
|
+
buffers any tokens that arrive before you connect, so you won't miss the
|
|
37
|
+
beginning of the response even if the job starts faster than expected.
|
|
38
|
+
|
|
39
|
+
channel_id = str(uuid.uuid4())
|
|
40
|
+
|
|
41
|
+
# Submit job first, consumer second — or consumer first, doesn't matter.
|
|
42
|
+
# The relay handles both orderings via buffering.
|
|
43
|
+
submit_slurm_job(relay_url, channel_id)
|
|
44
|
+
for token in RelayConsumer(relay_url, channel_id).stream():
|
|
45
|
+
...
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
import json
|
|
49
|
+
import logging
|
|
50
|
+
from typing import AsyncIterator, Iterator
|
|
51
|
+
|
|
52
|
+
# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class RelayConsumer:
    """
    WebSocket client that receives tokens from the relay.

    Args:
        relay_url: WebSocket URL of the relay server.
            Example: ``"wss://relay.example.com"`` (production)
            or ``"ws://localhost:8765"`` (local development).
            A trailing slash is stripped.
        channel_id: UUID string that pairs this consumer with its producer.
            Must be the same value that was passed to RelayProducer.
        encryption_key: Optional base64-encoded AES-256 key for decryption.
            Must match the key used by the producer. When set, each received
            message is decrypted before being parsed and yielded.
        relay_secret: Optional shared secret for relay authentication.
            Must match the relay server's ``--secret`` flag. It is
            percent-encoded before being placed in the query string.
        require_encryption: Keyword-only. When True, any message that is not
            an encrypted ``{"type": "enc", ...}`` envelope is rejected with
            ``RuntimeError`` instead of being passed through. This stops a
            relay from injecting plaintext tokens or a fake "done" message.
            Defaults to False for backward compatibility with unencrypted or
            mixed channels.

    Raises:
        ValueError: If ``require_encryption`` is True but no
            ``encryption_key`` was supplied.
    """

    def __init__(
        self,
        relay_url: str,
        channel_id: str,
        encryption_key: str = "",
        relay_secret: str = "",
        *,
        require_encryption: bool = False,
    ):
        if require_encryption and not encryption_key:
            raise ValueError("require_encryption=True requires an encryption_key")
        self.relay_url = relay_url.rstrip("/")
        self.channel_id = channel_id
        self.encryption_key = encryption_key
        self.relay_secret = relay_secret
        self.require_encryption = require_encryption

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _consume_url(self) -> str:
        """Build the /consume/{channel_id} URL, appending ?secret= if needed.

        The secret is percent-encoded (nothing is left "safe"). Otherwise
        characters such as ``&``, ``#``, ``+``, ``=`` or spaces would truncate
        the query string or be mis-parsed by the relay's query parser.
        """
        from urllib.parse import quote

        url = f"{self.relay_url}/consume/{self.channel_id}"
        if self.relay_secret:
            url += f"?secret={quote(self.relay_secret, safe='')}"
        return url

    def _decrypt(self, msg_str: str) -> str:
        """
        Decrypt a message if encryption is configured; otherwise pass through.

        The relay forwards messages as-is. If the producer encrypted them with
        AES-256-GCM (wrapping in {"type":"enc","d":"<base64blob>"}), this
        function unwraps and decrypts them back to the original JSON string.

        Raises:
            RuntimeError: If ``require_encryption`` is set and ``msg_str`` is
                not an encrypted envelope.
        """
        if not self.encryption_key:
            return msg_str
        if self.require_encryption:
            # Check before decrypt_message(), whose plaintext passthrough would
            # otherwise let unauthenticated messages through.
            envelope = json.loads(msg_str)
            if not isinstance(envelope, dict) or envelope.get("type") != "enc":
                raise RuntimeError(
                    "Received an unencrypted message on a channel that requires encryption"
                )
        from streamrelay.crypto import decrypt_message

        return decrypt_message(self.encryption_key, msg_str)

    def _parse_and_yield(self, raw: str) -> tuple:
        """
        Parse a raw WebSocket message and return the appropriate action.

        Despite the name, this does not yield; it returns an (action, value)
        pair that the stream loops act on.

        Returns:
            ("token", content)  — yield this token string to the caller
            ("done", None)      — stop iteration
            ("error", message)  — raise RuntimeError
            ("skip", None)      — ignore this message (unknown type, or a
                                  JSON value that is not an object)
        """
        msg = json.loads(self._decrypt(raw))
        if not isinstance(msg, dict):
            return ("skip", None)

        msg_type = msg.get("type")
        if msg_type == "token":
            return ("token", msg["content"])
        if msg_type == "done":
            return ("done", None)
        if msg_type == "error":
            return ("error", msg.get("message", "unknown error from producer"))
        return ("skip", None)

    def _warn_missing_done(self) -> None:
        """Log that the connection ended before the producer sent "done"."""
        logger.warning(
            "[streamrelay] channel=%s closed without a 'done' message; "
            "output may be truncated",
            self.channel_id[:8],
        )

    # -----------------------------------------------------------------------
    # Synchronous iterator
    # -----------------------------------------------------------------------

    def stream(self) -> Iterator[str]:
        """
        Connect to the relay and yield token strings synchronously.

        Blocks until each token arrives. Returns (stops iteration) when the
        producer sends a "done" message. The WebSocket connection is closed
        automatically when the generator exits. If the relay closes the
        connection normally without a "done" message, iteration also stops,
        and a warning is logged because the output is likely truncated.

        Yields:
            str: Each token string in arrival order.

        Raises:
            RuntimeError: If the producer sent an "error" message, or if
                ``require_encryption`` is set and a plaintext message arrives.

        Example::

            for token in RelayConsumer(relay_url, channel_id).stream():
                print(token, end="", flush=True)
        """
        from websockets.sync.client import connect as ws_connect

        url = self._consume_url()
        logger.debug("[streamrelay] consumer connecting: channel=%s", self.channel_id[:8])

        with ws_connect(url) as ws:
            for raw in ws:
                # Each raw message from the relay is one JSON string.
                action, value = self._parse_and_yield(raw)
                if action == "token":
                    yield value
                elif action == "done":
                    return  # clean end of stream
                elif action == "error":
                    raise RuntimeError(f"Producer error: {value}")
                # "skip": unknown message type, ignore and continue
        # Iteration over the connection ended without a "done" message.
        self._warn_missing_done()

    # -----------------------------------------------------------------------
    # Asynchronous iterator
    # -----------------------------------------------------------------------

    def __aiter__(self):
        """
        Enable ``async for token in RelayConsumer(...)`` syntax.

        Returns the async generator from astream(). This lets you use a
        RelayConsumer directly in an ``async for`` loop without calling
        .astream() explicitly.
        """
        return self.astream()

    async def astream(self) -> AsyncIterator[str]:
        """
        Connect to the relay and yield token strings asynchronously.

        Non-blocking: yields control to the event loop while waiting for
        each token. Suitable for FastAPI route handlers, aiohttp servers,
        or any asyncio application. It stops at "done", or logs a warning
        and stops if the connection ends without "done".

        Yields:
            str: Each token string in arrival order.

        Raises:
            RuntimeError: If the producer sent an "error" message, or if
                ``require_encryption`` is set and a plaintext message arrives.

        Example (FastAPI SSE endpoint)::

            @app.get("/stream")
            async def stream():
                async def generate():
                    async for token in RelayConsumer(relay_url, channel_id):
                        yield f"data: {token}\\n\\n"
                return StreamingResponse(generate(), media_type="text/event-stream")
        """
        from websockets.asyncio.client import connect as ws_connect

        url = self._consume_url()
        logger.debug("[streamrelay] async consumer connecting: channel=%s", self.channel_id[:8])

        async with ws_connect(url) as ws:
            async for raw in ws:
                action, value = self._parse_and_yield(raw)
                if action == "token":
                    yield value
                elif action == "done":
                    return
                elif action == "error":
                    raise RuntimeError(f"Producer error: {value}")
                # "skip": ignore and keep reading
        # Iteration over the connection ended without a "done" message.
        self._warn_missing_done()

    # -----------------------------------------------------------------------
    # Convenience: collect the full response as a single string
    # -----------------------------------------------------------------------

    def collect(self) -> str:
        """
        Stream all tokens and join them into a single string (blocking).

        Useful when you want the complete response but don't need to display
        it incrementally.

        Returns:
            str: The complete generated text.
        """
        return "".join(self.stream())

    async def acollect(self) -> str:
        """
        Async version of collect().

        Returns:
            str: The complete generated text.
        """
        return "".join([token async for token in self.astream()])
|
streamrelay/crypto.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""
|
|
2
|
+
streamrelay.crypto — AES-256-GCM end-to-end encryption for relay messages.
|
|
3
|
+
|
|
4
|
+
WHAT THIS FILE DOES
|
|
5
|
+
===================
|
|
6
|
+
The relay server is a public intermediary: it sees every message that flows
|
|
7
|
+
between the producer (HPC node) and the consumer (your application). By default
|
|
8
|
+
that means the relay operator can read token payloads.
|
|
9
|
+
|
|
10
|
+
This module adds optional end-to-end encryption so that the relay only ever
|
|
11
|
+
sees opaque ciphertext. The producer encrypts before sending; the consumer
|
|
12
|
+
decrypts after receiving. The relay cannot read message contents (it can still see message sizes, timing, and the channel ID).
|
|
13
|
+
|
|
14
|
+
WHY AES-256-GCM
|
|
15
|
+
===============
|
|
16
|
+
AES-256-GCM is the standard choice for this use case:
|
|
17
|
+
|
|
18
|
+
- AES-256: 256-bit key — infeasible to brute-force with current hardware.
|
|
19
|
+
- GCM (Galois/Counter Mode): "authenticated encryption" — provides both
|
|
20
|
+
confidentiality (nobody can read the message) AND integrity (any tampering
|
|
21
|
+
at the relay is detected and raises an exception at decrypt time).
|
|
22
|
+
- Fresh nonce per message: GCM requires a unique nonce (number-used-once)
|
|
23
|
+
for every encryption. We use os.urandom(12) — cryptographically random
|
|
24
|
+
12 bytes — so even if you send the same token twice, the ciphertexts are
|
|
25
|
+
different. This defeats pattern analysis, but it does NOT stop the relay from replaying, reordering, or dropping whole messages.
|
|
26
|
+
|
|
27
|
+
WIRE FORMAT
|
|
28
|
+
===========
|
|
29
|
+
An encrypted message is a JSON string wrapping a single base64-encoded blob:
|
|
30
|
+
|
|
31
|
+
{"type": "enc", "d": "<base64(nonce[12 bytes] + ciphertext + tag[16 bytes])>"}
|
|
32
|
+
|
|
33
|
+
The nonce (12 bytes) and GCM authentication tag (16 bytes) are packed together
|
|
34
|
+
with the ciphertext into a single base64 blob. This makes it easy to pass over
|
|
35
|
+
JSON without any binary escaping.
|
|
36
|
+
|
|
37
|
+
The relay forwards this JSON string unchanged. It doesn't know or care that
|
|
38
|
+
it contains encrypted data.
|
|
39
|
+
|
|
40
|
+
BACKWARD COMPATIBILITY
|
|
41
|
+
======================
|
|
42
|
+
If decrypt_message() receives a message that is NOT of type "enc" (i.e. an
|
|
43
|
+
unencrypted message), it passes it through unchanged. This means you can
|
|
44
|
+
enable encryption on a running system without breaking existing unencrypted
|
|
45
|
+
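connections. Trade-off: because plaintext messages are accepted, a relay could inject unencrypted messages that bypass the integrity check described above.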
|
|
46
|
+
|
|
47
|
+
SETUP
|
|
48
|
+
=====
|
|
49
|
+
Generate a key once and share it between the producer and consumer via
|
|
50
|
+
environment variables or a secrets manager:
|
|
51
|
+
|
|
52
|
+
python -c "from streamrelay import generate_key; print(generate_key())"
|
|
53
|
+
# Outputs something like: xK3mP9vQ... (44 characters, base64-encoded)
|
|
54
|
+
|
|
55
|
+
Store it in your .env file:
|
|
56
|
+
RELAY_ENCRYPTION_KEY=xK3mP9vQ...
|
|
57
|
+
|
|
58
|
+
Then pass it to both sides:
|
|
59
|
+
RelayProducer(relay_url, channel_id, encryption_key=os.getenv("RELAY_ENCRYPTION_KEY"))
|
|
60
|
+
RelayConsumer(relay_url, channel_id, encryption_key=os.getenv("RELAY_ENCRYPTION_KEY"))
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
import base64
|
|
64
|
+
import json
|
|
65
|
+
import os
|
|
66
|
+
|
|
67
|
+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
68
|
+
|
|
69
|
+
# AES-GCM sizes from NIST SP 800-38D: a 96-bit nonce (the length recommended
# for GCM) and the 128-bit authentication tag that AESGCM.encrypt appends.
# _TAG_SIZE is informational; no function in this module references it.
_NONCE_SIZE, _TAG_SIZE = 12, 16
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def generate_key() -> str:
    """
    Create a new random AES-256 key encoded as base64 text.

    The result is plain ASCII, so it can go straight into a .env file or an
    environment variable. Generate it once per deployment, give the same
    value to both producer and consumer, and keep it secret.

    Returns:
        str: 44-character base64 encoding of 32 random bytes (256 bits).

    Example::

        from streamrelay import generate_key
        key = generate_key()
        print(key)  # store this in your .env as RELAY_ENCRYPTION_KEY
    """
    raw_key = os.urandom(32)  # 32 bytes of OS-provided randomness
    encoded = base64.b64encode(raw_key)  # binary -> printable ASCII bytes
    return encoded.decode()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def encrypt_message(key_b64: str, plaintext_json: str) -> str:
    """
    Seal a JSON message with AES-GCM and wrap it in the relay envelope.

    Given a plaintext JSON string (e.g. ``'{"type":"token","content":"Hello"}'``),
    produce ``'{"type": "enc", "d": "<base64blob>"}'``. The blob is the
    12-byte random nonce followed by the ciphertext and GCM tag. The relay
    forwards the envelope unchanged, and decrypt_message() reverses it.

    Args:
        key_b64: Base64-encoded 32-byte AES-256 key (from generate_key()).
        plaintext_json: Any JSON string to encrypt (encoded as UTF-8).

    Returns:
        str: JSON string ``{"type": "enc", "d": "<base64(nonce+ciphertext+tag)>"}``
    """
    key_bytes = base64.b64decode(key_b64)
    nonce = os.urandom(_NONCE_SIZE)  # never reuse a nonce with the same key
    cipher = AESGCM(key_bytes)

    # encrypt() returns ciphertext with the authentication tag already appended.
    sealed = cipher.encrypt(nonce, plaintext_json.encode(), None)

    # The receiver needs the nonce to decrypt, so it is prepended to the payload.
    payload = base64.b64encode(nonce + sealed).decode()
    return json.dumps({"type": "enc", "d": payload})
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def decrypt_message(key_b64: str, msg_str: str, *, strict: bool = False) -> str:
    """
    Decrypt a relay message, or pass through if it is not encrypted.

    If the message is a JSON object with ``"type": "enc"``, decrypt it and
    return the original plaintext JSON string. Any other JSON value (an object
    of a different type, or a non-object such as a list) is returned
    unchanged. This is a backward-compatible passthrough for unencrypted
    messages, disabled when ``strict`` is True.

    Args:
        key_b64: Base64-encoded 32-byte AES-256 key (must match the producer's key).
        msg_str: JSON string received from the relay.
        strict: Keyword-only. When True, refuse the plaintext passthrough and
            raise ``ValueError`` for anything that is not an ``"enc"``
            envelope. Use it when every message on the channel is expected to
            be encrypted. Otherwise the relay can inject plaintext messages
            that bypass the GCM integrity check entirely.

    Returns:
        str: Decrypted inner JSON string, or original ``msg_str`` if not encrypted.

    Raises:
        ValueError: If ``msg_str`` is not valid JSON, or if ``strict`` is True
            and the message is not an encrypted envelope.
        KeyError: If an ``"enc"`` envelope has no ``"d"`` payload.
        cryptography.exceptions.InvalidTag: If the ciphertext was tampered with.
            This means the relay (or someone in between) modified the message.
            Treat this as a security event.
    """
    msg = json.loads(msg_str)

    is_envelope = isinstance(msg, dict) and msg.get("type") == "enc"
    if not is_envelope:
        if strict:
            raise ValueError("expected an encrypted relay message (type 'enc')")
        # Passthrough lets the consumer handle both encrypted and unencrypted
        # messages on the same channel (useful during a rolling migration).
        return msg_str

    # Unpack the blob: the first _NONCE_SIZE bytes are the nonce, and the rest
    # is ciphertext with the GCM tag appended.
    blob = base64.b64decode(msg["d"])
    nonce, ciphertext_with_tag = blob[:_NONCE_SIZE], blob[_NONCE_SIZE:]

    aesgcm = AESGCM(base64.b64decode(key_b64))
    # decrypt() verifies the GCM tag first and raises InvalidTag instead of
    # returning corrupted plaintext. This is the integrity guarantee.
    plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, None)
    return plaintext.decode()
|
streamrelay/executor.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""
|
|
2
|
+
streamrelay.executor — High-level API for streaming from Globus Compute.
|
|
3
|
+
|
|
4
|
+
This is the primary user-facing class. It wraps channel ID management, Globus
|
|
5
|
+
job submission, and relay consumption into a single ``async for`` loop::
|
|
6
|
+
|
|
7
|
+
from streamrelay import StreamingExecutor
|
|
8
|
+
|
|
9
|
+
async with StreamingExecutor(endpoint_id, relay_url, secret, key) as executor:
|
|
10
|
+
async for token in executor.stream(my_vllm_function, prompt="Hello"):
|
|
11
|
+
print(token, end="", flush=True)
|
|
12
|
+
|
|
13
|
+
The ``stream()`` method:
|
|
14
|
+
1. Generates a random channel ID.
|
|
15
|
+
2. Submits ``fn`` to the Globus Compute endpoint with the channel ID and relay
|
|
16
|
+
URL passed as extra keyword arguments.
|
|
17
|
+
3. Immediately connects to the relay as a consumer and yields tokens as they
|
|
18
|
+
arrive — without waiting for Globus to complete.
|
|
19
|
+
|
|
20
|
+
``fn`` must accept two extra kwargs automatically injected by the executor:
|
|
21
|
+
``relay_url`` (str) and ``channel_id`` (str).
|
|
22
|
+
Optionally also ``relay_secret`` and ``encryption_key`` if you set those.
|
|
23
|
+
|
|
24
|
+
If ``streamrelay`` is installed on the HPC endpoint workers, ``fn`` can use
|
|
25
|
+
``RelayProducer`` directly. If not, embed the inline pattern from
|
|
26
|
+
``remote_vllm_streaming`` in STREAM's ``globus_compute_client.py``.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import uuid
|
|
30
|
+
from collections.abc import AsyncIterator
|
|
31
|
+
from typing import Callable
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class StreamingExecutor:
    """Submit a Globus Compute function and receive its output via relay.

    Args:
        endpoint_id: Globus Compute endpoint UUID.
        relay_url: WebSocket URL of the relay server.
        relay_secret: Optional shared secret (must match relay's ``--secret``).
        encryption_key: Optional base64 AES-256 key for E2E encryption.
        consumer_timeout: Seconds to wait for the relay connection and the
            first token before raising ``TimeoutError``. ``None`` waits
            indefinitely. Tokens after the first are not subject to this limit.
        result_timeout: Seconds to wait, once the token stream has ended, for
            Globus Compute to report the task's final status.
    """

    def __init__(
        self,
        endpoint_id: str,
        relay_url: str,
        relay_secret: str = "",
        encryption_key: str = "",
        consumer_timeout: float = 300.0,
        result_timeout: float = 10.0,
    ):
        self.endpoint_id = endpoint_id
        self.relay_url = relay_url
        self.relay_secret = relay_secret
        self.encryption_key = encryption_key
        self.consumer_timeout = consumer_timeout
        self.result_timeout = result_timeout
        # Created on first use by _get_gc_executor(), so constructing this
        # object does not import globus_compute_sdk.
        self._executor = None

    # ------------------------------------------------------------------
    # Context manager
    # ------------------------------------------------------------------

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Returns None, so exceptions from the ``async with`` body propagate.
        self.close()

    # ------------------------------------------------------------------
    # Lazy Globus executor
    # ------------------------------------------------------------------

    def _get_gc_executor(self):
        """Return the Globus Compute Executor, creating it on first call."""
        if self._executor is None:
            from globus_compute_sdk import Executor

            self._executor = Executor(endpoint_id=self.endpoint_id)
        return self._executor

    def close(self):
        """Shut down the underlying Globus Compute executor (best effort).

        Failures are logged at DEBUG level rather than raised. This runs from
        ``__aexit__`` and must not mask the caller's own exception.
        """
        executor, self._executor = self._executor, None
        if executor is None:
            return
        try:
            executor.shutdown(wait=False)
        except Exception:
            import logging

            logging.getLogger(__name__).debug(
                "[streamrelay] Globus executor shutdown failed", exc_info=True
            )

    @staticmethod
    def _describe_failure(future) -> str:
        """Return a message suffix describing ``future``'s error, or "" if none is known yet."""
        if not future.done() or future.cancelled():
            return ""
        exc = future.exception()
        return f" (Globus Compute reported: {exc})" if exc is not None else ""

    # ------------------------------------------------------------------
    # Main API
    # ------------------------------------------------------------------

    async def stream(self, fn: Callable, *args, **kwargs) -> AsyncIterator[str]:
        """Submit ``fn`` to the endpoint and stream its output token by token.

        ``fn`` will be called on the HPC node with ``*args, **kwargs`` PLUS
        these additional keyword arguments injected automatically. They
        overwrite any caller-supplied values with the same names:

        - ``relay_url`` — where to send tokens
        - ``channel_id`` — this request's unique channel
        - ``relay_secret`` — auth secret (if configured)
        - ``encryption_key`` — E2E encryption key (if configured)

        Args:
            fn: Callable to submit. Must send tokens to the relay
                (e.g., use :class:`~streamrelay.producer.RelayProducer`).
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Yields:
            str: Token strings in arrival order.

        Raises:
            TimeoutError: If no token (or end of stream) arrives within
                ``consumer_timeout`` seconds.
            RuntimeError: If the producer sent an "error" message, or if
                Globus Compute reports a failure (or no result within
                ``result_timeout`` seconds) after the stream ended.
        """
        import asyncio
        import concurrent.futures

        from streamrelay.consumer import RelayConsumer

        channel_id = str(uuid.uuid4())

        # Inject relay coordinates into the function's kwargs.
        kwargs["relay_url"] = self.relay_url
        kwargs["channel_id"] = channel_id
        if self.relay_secret:
            kwargs["relay_secret"] = self.relay_secret
        if self.encryption_key:
            kwargs["encryption_key"] = self.encryption_key

        # Submit to Globus Compute; submit() hands back a Future.
        future = self._get_gc_executor().submit(fn, *args, **kwargs)

        # Connect as consumer and yield tokens in real time.
        # The relay buffers any tokens that arrive before we connect.
        consumer = RelayConsumer(
            relay_url=self.relay_url,
            channel_id=channel_id,
            encryption_key=self.encryption_key,
            relay_secret=self.relay_secret,
        )
        tokens = consumer.astream()

        async def next_token():
            # Return (True, token) for the next token, or (False, None) if the
            # stream ended. ``async for`` keeps StopAsyncIteration from having
            # to cross asyncio.wait_for.
            async for token in tokens:
                return True, token
            return False, None

        try:
            try:
                # Bound the wait for connection + first token so that a job
                # that never starts cannot hang this generator forever.
                has_token, token = await asyncio.wait_for(
                    next_token(), timeout=self.consumer_timeout
                )
            except asyncio.TimeoutError:
                raise TimeoutError(
                    f"No token received on channel {channel_id[:8]} within "
                    f"{self.consumer_timeout}s{self._describe_failure(future)}"
                ) from None
            if has_token:
                yield token
                async for token in tokens:
                    yield token
        finally:
            # Close the relay connection on every exit path: timeout, producer
            # error, or the caller abandoning iteration early.
            await tokens.aclose()

        # After streaming, check for Globus-level errors (infrastructure faults).
        # Future.result() blocks, so it runs in a worker thread to keep the
        # event loop responsive while Globus Compute reports the final status.
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, future.result, self.result_timeout)
        except concurrent.futures.TimeoutError as e:
            raise RuntimeError(
                f"Globus Compute did not report a result within "
                f"{self.result_timeout}s after the stream ended"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Globus Compute reported an error: {e}") from e
|