flopscope-server 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flopscope_server-0.3.0/.gitignore +59 -0
- flopscope_server-0.3.0/PKG-INFO +9 -0
- flopscope_server-0.3.0/pyproject.toml +32 -0
- flopscope_server-0.3.0/src/flopscope_server/__init__.py +3 -0
- flopscope_server-0.3.0/src/flopscope_server/__main__.py +42 -0
- flopscope_server-0.3.0/src/flopscope_server/_array_store.py +97 -0
- flopscope_server-0.3.0/src/flopscope_server/_comms_tracker.py +77 -0
- flopscope_server-0.3.0/src/flopscope_server/_protocol.py +233 -0
- flopscope_server-0.3.0/src/flopscope_server/_request_handler.py +440 -0
- flopscope_server-0.3.0/src/flopscope_server/_server.py +375 -0
- flopscope_server-0.3.0/src/flopscope_server/_session.py +171 -0
- flopscope_server-0.3.0/tests/test_array_store.py +221 -0
- flopscope_server-0.3.0/tests/test_bugfixes_round2.py +178 -0
- flopscope_server-0.3.0/tests/test_bugfixes_round3.py +211 -0
- flopscope_server-0.3.0/tests/test_comms_tracker.py +350 -0
- flopscope_server-0.3.0/tests/test_new_types.py +49 -0
- flopscope_server-0.3.0/tests/test_protocol.py +289 -0
- flopscope_server-0.3.0/tests/test_request_handler.py +380 -0
- flopscope_server-0.3.0/tests/test_server.py +310 -0
- flopscope_server-0.3.0/tests/test_session.py +255 -0
- flopscope_server-0.3.0/tests/test_version_handshake.py +92 -0
- flopscope_server-0.3.0/uv.lock +534 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
.aicrowd/
|
|
2
|
+
|
|
3
|
+
# Python bytecode & build artifacts
|
|
4
|
+
__pycache__/
|
|
5
|
+
*.pyc
|
|
6
|
+
*.pyo
|
|
7
|
+
*.pyd
|
|
8
|
+
*.egg-info/
|
|
9
|
+
*.egg
|
|
10
|
+
dist/
|
|
11
|
+
build/
|
|
12
|
+
*.whl
|
|
13
|
+
*.so
|
|
14
|
+
|
|
15
|
+
# Virtual environments
|
|
16
|
+
.venv/
|
|
17
|
+
venv/
|
|
18
|
+
env/
|
|
19
|
+
|
|
20
|
+
# IDE & editors
|
|
21
|
+
.vscode/
|
|
22
|
+
.idea/
|
|
23
|
+
*.swp
|
|
24
|
+
*.swo
|
|
25
|
+
*~
|
|
26
|
+
|
|
27
|
+
# Testing & coverage
|
|
28
|
+
.pytest_cache/
|
|
29
|
+
.coverage
|
|
30
|
+
coverage.xml
|
|
31
|
+
htmlcov/
|
|
32
|
+
.tox/
|
|
33
|
+
.nox/
|
|
34
|
+
|
|
35
|
+
# mkdocs generated site
|
|
36
|
+
site/
|
|
37
|
+
|
|
38
|
+
# Website (Fumadocs / Next.js)
|
|
39
|
+
website/node_modules/
|
|
40
|
+
website/.next/
|
|
41
|
+
website/out/
|
|
42
|
+
website/public/llms.txt
|
|
43
|
+
website/public/llms-full.txt
|
|
44
|
+
|
|
45
|
+
# Claude Code session data
|
|
46
|
+
.claude/
|
|
47
|
+
|
|
48
|
+
# Benchmarks (runtime output)
|
|
49
|
+
.benchmarks/
|
|
50
|
+
|
|
51
|
+
# ZMQ sockets (runtime)
|
|
52
|
+
*.sock
|
|
53
|
+
|
|
54
|
+
# OS junk
|
|
55
|
+
.DS_Store
|
|
56
|
+
Thumbs.db
|
|
57
|
+
Desktop.ini
|
|
58
|
+
weights.json
|
|
59
|
+
report.html
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: flopscope-server
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Backend server for flopscope client-server architecture
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Requires-Dist: flopscope==0.3.0
|
|
7
|
+
Requires-Dist: msgpack>=1.0.0
|
|
8
|
+
Requires-Dist: numpy<2.5.0,>=2.0.0
|
|
9
|
+
Requires-Dist: pyzmq>=26.0.0
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "flopscope-server"
|
|
7
|
+
version = "0.3.0"
|
|
8
|
+
description = "Backend server for flopscope client-server architecture"
|
|
9
|
+
requires-python = ">=3.10"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"flopscope==0.3.0",
|
|
12
|
+
"numpy>=2.0.0,<2.5.0",
|
|
13
|
+
"pyzmq>=26.0.0",
|
|
14
|
+
"msgpack>=1.0.0",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.hatch.build.targets.wheel]
|
|
18
|
+
packages = ["src/flopscope_server"]
|
|
19
|
+
|
|
20
|
+
[project.scripts]
|
|
21
|
+
flopscope-server = "flopscope_server.__main__:main"
|
|
22
|
+
|
|
23
|
+
[tool.uv.sources]
|
|
24
|
+
flopscope = { path = "../", editable = true }
|
|
25
|
+
|
|
26
|
+
[dependency-groups]
|
|
27
|
+
dev = [
|
|
28
|
+
"pytest>=9.0.3",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[tool.pytest.ini_options]
|
|
32
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Entry point for ``python -m flopscope_server``."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
from flopscope_server._server import FlopscopeServer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main() -> None:
    """Parse command-line options and run the flopscope server until it exits.

    Command-line options
    --------------------
    --url:
        ZMQ endpoint to bind (default ``ipc:///tmp/flopscope.sock``).
    --timeout:
        Session inactivity timeout in seconds (default ``60.0``).

    A ``KeyboardInterrupt`` (Ctrl-C) is treated as a normal shutdown request.
    ``server.stop()`` is called on every exit path from ``server.run()``:
    normal return, Ctrl-C, or an unexpected exception. The server's shutdown
    hook therefore always runs.
    """
    parser = argparse.ArgumentParser(
        description="Flopscope budget-controlled compute server",
    )
    parser.add_argument(
        "--url",
        default="ipc:///tmp/flopscope.sock",
        help="ZMQ endpoint to bind (default: ipc:///tmp/flopscope.sock)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="Session inactivity timeout in seconds (default: 60.0)",
    )
    args = parser.parse_args()

    print(
        f"[flopscope-server] binding to {args.url} (timeout={args.timeout}s)",
        file=sys.stderr,
    )

    server = FlopscopeServer(url=args.url, session_timeout_s=args.timeout)
    try:
        server.run()
    except KeyboardInterrupt:
        print("\n[flopscope-server] shutting down", file=sys.stderr)
    finally:
        # Previously stop() was reached only after Ctrl-C; an unexpected
        # exception from run() skipped the server's shutdown entirely.
        server.stop()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Lets ``python -m flopscope_server`` behave like the ``flopscope-server``
# console script, which also points at ``main``.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""ArrayStore — in-process dict-based mapping from handle IDs to numpy arrays."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
#: Upper bound on how many arrays a single ArrayStore may hold at once.
#: Read once at import time from ``FLOPSCOPE_MAX_ARRAY_COUNT`` (default 100000).
MAX_ARRAY_COUNT = int(os.getenv("FLOPSCOPE_MAX_ARRAY_COUNT", "100000"))
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ArrayStore:
    """Dictionary-backed registry that hands out string handles for arrays.

    Every stored array gets a handle of the form ``"a<N>"``. ``N`` comes from
    an internal counter that only ever grows. Freeing or clearing arrays never
    rewinds it, so a handle is never reused within one store.
    """

    def __init__(self) -> None:
        # handle ID -> stored array object
        self._arrays: dict[str, Any] = {}
        # numeric suffix for the next handle; never decremented
        self._counter: int = 0

    # ------------------------------------------------------------------
    # Core operations
    # ------------------------------------------------------------------

    def put(self, arr: Any) -> str:
        """Register *arr* under a fresh handle and return that handle.

        Raises
        ------
        MemoryError
            When the store already holds :data:`MAX_ARRAY_COUNT` arrays.
        """
        limit = MAX_ARRAY_COUNT
        if len(self._arrays) >= limit:
            raise MemoryError(f"array store limit reached: {limit} arrays")
        new_id = "a" + str(self._counter)
        self._arrays[new_id] = arr
        self._counter += 1
        return new_id

    def get(self, handle: str) -> Any:
        """Look up the array registered under *handle*.

        Raises
        ------
        KeyError
            When no array is registered under *handle*.
        """
        try:
            return self._arrays[handle]
        except KeyError:
            raise KeyError(f"Array handle {handle!r} not found in store") from None

    def metadata(self, handle: str) -> dict:
        """Describe the array registered under *handle*.

        Returns
        -------
        dict
            ``{"id": handle, "shape": list[int], "dtype": str}``. A
            ``"symmetry"`` entry is added when the array carries a non-None
            ``symmetry`` attribute; its value is ``symmetry.to_payload()``.

        Raises
        ------
        KeyError
            When no array is registered under *handle* (raised by :meth:`get`).
        """
        arr = self.get(handle)
        info = {"id": handle, "shape": list(arr.shape), "dtype": str(arr.dtype)}
        sym = getattr(arr, "symmetry", None)
        if sym is not None:
            info["symmetry"] = sym.to_payload()
        return info

    def free(self, handles: list[str]) -> None:
        """Drop every listed handle. Handles that are not present are skipped."""
        arrays = self._arrays
        for h in handles:
            if h in arrays:
                del arrays[h]

    def clear(self) -> None:
        """Drop every stored array. The handle counter is left unchanged."""
        self._arrays.clear()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def count(self) -> int:
        """How many arrays the store currently holds."""
        return len(self._arrays)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""CommsTracker — accumulates per-request timing and byte counts for a session."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CommsTracker:
    """Per-session tally of request traffic (bytes) and timing (nanoseconds).

    Each :meth:`record_request` call adds one request's figures to the running
    totals. :meth:`summary` reports those totals plus the share of time spent
    on communication.
    """

    def __init__(self) -> None:
        # Request counters.
        self._request_count: int = 0
        self._fetch_count: int = 0
        # Byte totals.
        self._total_bytes_sent: int = 0
        self._total_bytes_received: int = 0
        # Timing totals, in nanoseconds.
        self._total_comms_overhead_ns: int = 0
        self._total_compute_time_ns: int = 0

    def record_request(
        self,
        *,
        bytes_received: int,
        bytes_sent: int,
        comms_overhead_ns: int,
        compute_time_ns: int,
        is_fetch: bool,
    ) -> None:
        """Add one request's figures to the running totals.

        Parameters
        ----------
        bytes_received:
            Bytes received for this request.
        bytes_sent:
            Bytes sent for this request.
        comms_overhead_ns:
            Communication overhead of this request, in nanoseconds.
        compute_time_ns:
            Compute time of this request, in nanoseconds.
        is_fetch:
            True when the request was a fetch (array retrieval).
        """
        self._request_count += 1
        if is_fetch:
            self._fetch_count += 1
        self._total_bytes_sent += bytes_sent
        self._total_bytes_received += bytes_received
        self._total_comms_overhead_ns += comms_overhead_ns
        self._total_compute_time_ns += compute_time_ns

    def summary(self) -> dict:
        """Snapshot the accumulated totals.

        Returns
        -------
        dict with keys, in this order:
            request_count: requests recorded so far
            fetch_count: fetch requests recorded so far
            total_bytes_sent: sum of bytes_sent
            total_bytes_received: sum of bytes_received
            total_comms_overhead_ns: sum of comms_overhead_ns
            total_compute_time_ns: sum of compute_time_ns
            overhead_ratio: comms / (comms + compute), or 0.0 when that sum is 0
        """
        comms = self._total_comms_overhead_ns
        compute = self._total_compute_time_ns
        denom = comms + compute
        ratio = 0.0 if denom == 0 else comms / denom

        return dict(
            request_count=self._request_count,
            fetch_count=self._fetch_count,
            total_bytes_sent=self._total_bytes_sent,
            total_bytes_received=self._total_bytes_received,
            total_comms_overhead_ns=comms,
            total_compute_time_ns=compute,
            overhead_ratio=ratio,
        )
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""Protocol layer for flopscope server — message encoding/decoding.
|
|
2
|
+
|
|
3
|
+
Message format: msgpack over ZMQ.
|
|
4
|
+
|
|
5
|
+
All messages are msgpack-encoded dicts. decode_request unpacks them with
raw=True, so every msgpack string (keys included) arrives as bytes, and then
normalises the result selectively:
|
|
7
|
+
- top-level string fields (op, dtype, request_id) are decoded to str
|
|
8
|
+
- binary payload fields (data) are kept as bytes
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import msgpack
|
|
16
|
+
|
|
17
|
+
from flopscope._registry import REGISTRY
|
|
18
|
+
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
# Whitelist
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
#: Control-plane ops that the server handles itself; they are not entries in
#: the numpy REGISTRY.
_PROTOCOL_OPS: frozenset[str] = frozenset(
    (
        # session / budget lifecycle
        "hello",
        "budget_open",
        "budget_close",
        "budget_status",
        # array transfer and lifetime
        "fetch",
        "fetch_slice",
        "free",
        "create_from_data",
        # array-method style ops
        "__getitem__",
        "astype",
    )
)
|
|
38
|
+
|
|
39
|
+
#: Every op name a client may request: all REGISTRY entries plus the
#: protocol-level ops above.
WHITELIST: frozenset[str] = _PROTOCOL_OPS.union(REGISTRY.keys())
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# String fields that should always be decoded from bytes -> str
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
#: Top-level request fields whose bytes values are decoded to str (UTF-8).
_STRING_FIELDS: frozenset[str] = frozenset(("op", "dtype", "request_id"))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ---------------------------------------------------------------------------
|
|
50
|
+
# Exceptions
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class InvalidRequestError(Exception):
    """A client request was malformed or named an op that is not permitted."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
# decode_request
|
|
60
|
+
# ---------------------------------------------------------------------------
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def decode_request(raw: bytes) -> dict:
    """Decode a msgpack-encoded request from the client.

    The payload is unpacked with ``raw=True``, so every msgpack string arrives
    as :class:`bytes`. Keys and the fields listed in :data:`_STRING_FIELDS` are
    then decoded to :class:`str`. Every other value (notably binary payloads
    such as ``data``) is left untouched.

    Parameters
    ----------
    raw:
        Raw msgpack bytes received from the client.

    Returns
    -------
    dict
        Decoded request with top-level keys as strings.
        Binary payload fields (e.g. ``data``) are kept as :class:`bytes`.

    Raises
    ------
    InvalidRequestError
        If *raw* is empty, cannot be unpacked, is not a msgpack map, contains
        a key or string field that is not valid UTF-8, or is missing the
        ``op`` field.
    """
    if not raw:
        raise InvalidRequestError("malformed request: empty bytes")

    try:
        # Use raw=True so that we get bytes keys and can selectively decode.
        msg_raw = msgpack.unpackb(raw, raw=True)
    except Exception as exc:
        raise InvalidRequestError(f"malformed request: {exc}") from exc

    if not isinstance(msg_raw, dict):
        raise InvalidRequestError("malformed request: top-level value must be a dict")

    # Normalise keys and selected string values.
    msg: dict[str, Any] = {}
    for k, v in msg_raw.items():
        try:
            key: str = k.decode("utf-8") if isinstance(k, bytes) else str(k)
            if key in _STRING_FIELDS and isinstance(v, bytes):
                value: Any = v.decode("utf-8")
            else:
                value = v
        except UnicodeDecodeError as exc:
            # These bytes come from the client. Report invalid UTF-8 as a
            # protocol error instead of leaking UnicodeDecodeError to callers
            # that expect InvalidRequestError.
            raise InvalidRequestError(f"malformed request: {exc}") from exc
        msg[key] = value

    if "op" not in msg:
        raise InvalidRequestError("malformed request: missing 'op' field")

    return msg
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
# validate_request
|
|
116
|
+
# ---------------------------------------------------------------------------
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def validate_request(msg: dict) -> None:
    """Check that the op in *msg* is on the permitted whitelist.

    Parameters
    ----------
    msg:
        Decoded request dict (as returned by :func:`decode_request`).

    Raises
    ------
    InvalidRequestError
        If the op is missing, is not a string, or is not in :data:`WHITELIST`.
    """
    op = msg.get("op", "")
    # ``op`` can be any msgpack value sent by the client. An unhashable value
    # (list/map) would make the frozenset membership test raise TypeError, so
    # reject non-strings explicitly with the same error as unknown names.
    if not isinstance(op, str) or op not in WHITELIST:
        raise InvalidRequestError(f"unknown op: {op!r}")
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# ---------------------------------------------------------------------------
|
|
138
|
+
# encode_response
|
|
139
|
+
# ---------------------------------------------------------------------------
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def encode_response(result: Any, budget: int, comms_overhead_ns: int) -> bytes:
    """Serialise the reply for an operation that completed successfully.

    Parameters
    ----------
    result:
        Value produced by the operation; it must be packable by msgpack.
    budget:
        Budget remaining once the operation has run.
    comms_overhead_ns:
        Round-trip communication overhead, in nanoseconds.

    Returns
    -------
    bytes
        msgpack map with keys ``status`` (always ``"ok"``), ``result``,
        ``budget`` and ``comms_overhead_ns``, packed with ``use_bin_type=True``.
    """
    return msgpack.packb(
        dict(
            status="ok",
            result=result,
            budget=budget,
            comms_overhead_ns=comms_overhead_ns,
        ),
        use_bin_type=True,
    )
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# ---------------------------------------------------------------------------
|
|
169
|
+
# encode_error_response
|
|
170
|
+
# ---------------------------------------------------------------------------
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def encode_error_response(error_type: str, message: str) -> bytes:
    """Serialise the reply for a request that failed.

    Parameters
    ----------
    error_type:
        Exception class name to report (e.g. ``"InvalidRequestError"``).
    message:
        Human-readable explanation of the failure.

    Returns
    -------
    bytes
        msgpack map with keys ``status`` (always ``"error"``), ``error_type``
        and ``message``, packed with ``use_bin_type=True``.
    """
    return msgpack.packb(
        dict(status="error", error_type=error_type, message=message),
        use_bin_type=True,
    )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# ---------------------------------------------------------------------------
|
|
197
|
+
# encode_fetch_response
|
|
198
|
+
# ---------------------------------------------------------------------------
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def encode_fetch_response(
    data: bytes,
    shape: tuple[int, ...],
    dtype: str,
    comms_overhead_ns: int,
) -> bytes:
    """Serialise the reply to a fetch, carrying the array's raw bytes.

    Parameters
    ----------
    data:
        The array's raw bytes (e.g. from ``array.tobytes()``).
    shape:
        Array shape, given as a tuple or list of ints.
    dtype:
        NumPy dtype name (e.g. ``"float32"``).
    comms_overhead_ns:
        Round-trip communication overhead, in nanoseconds.

    Returns
    -------
    bytes
        msgpack map with keys ``status`` (always ``"ok"``), ``data``,
        ``shape`` (as a list), ``dtype`` and ``comms_overhead_ns``. It is
        packed with ``use_bin_type=True`` so that ``data`` travels as
        msgpack bin.
    """
    body = dict(
        status="ok",
        data=data,
        shape=list(shape),
        dtype=dtype,
        comms_overhead_ns=comms_overhead_ns,
    )
    return msgpack.packb(body, use_bin_type=True)
|