flopscope-server 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """Flopscope backend server — executes numpy operations on behalf of remote clients."""
2
+
3
#: Release version of the flopscope-server package.
__version__: str = "0.3.0"
@@ -0,0 +1,42 @@
1
+ """Entry point for ``python -m flopscope_server``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+
8
+ from flopscope_server._server import FlopscopeServer
9
+
10
+
11
def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, then run the Flopscope server until Ctrl-C.

    Parameters
    ----------
    argv:
        Arguments to parse, excluding the program name. ``None`` (the
        default) makes :mod:`argparse` read ``sys.argv[1:]``, so running
        ``python -m flopscope_server`` behaves exactly as before; passing an
        explicit list lets tests or embedding code drive the CLI without
        touching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Flopscope budget-controlled compute server",
    )
    parser.add_argument(
        "--url",
        default="ipc:///tmp/flopscope.sock",
        help="ZMQ endpoint to bind (default: ipc:///tmp/flopscope.sock)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="Session inactivity timeout in seconds (default: 60.0)",
    )
    args = parser.parse_args(argv)

    print(
        f"[flopscope-server] binding to {args.url} (timeout={args.timeout}s)",
        file=sys.stderr,
    )

    server = FlopscopeServer(url=args.url, session_timeout_s=args.timeout)
    try:
        server.run()
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal shutdown request: announce it, then
        # ask the server to stop.
        print("\n[flopscope-server] shutting down", file=sys.stderr)
        server.stop()


if __name__ == "__main__":
    main()
@@ -0,0 +1,97 @@
1
+ """ArrayStore — in-process dict-based mapping from handle IDs to numpy arrays."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
#: Upper bound on how many arrays one ArrayStore may hold at once. Read from
#: the FLOPSCOPE_MAX_ARRAY_COUNT environment variable at import time, falling
#: back to 100000 when the variable is unset.
MAX_ARRAY_COUNT = int(os.getenv("FLOPSCOPE_MAX_ARRAY_COUNT", "100000"))
12
+
13
+
14
class ArrayStore:
    """Dictionary-backed registry that hands out string IDs for stored arrays.

    IDs take the form ``"a0"``, ``"a1"``, ``"a2"``, … and come from a counter
    that only moves forward. Neither :meth:`free` nor :meth:`clear` rewinds
    it, so one store never issues the same ID twice.
    """

    def __init__(self) -> None:
        # handle ID -> stored object
        self._arrays: dict[str, Any] = {}
        # numeric suffix of the next handle ID to issue
        self._counter: int = 0

    # ------------------------------------------------------------------
    # Core operations
    # ------------------------------------------------------------------

    def put(self, arr: Any) -> str:
        """Add *arr* to the store and return the newly issued handle ID.

        Raises
        ------
        MemoryError
            When the store already holds :data:`MAX_ARRAY_COUNT` arrays.
        """
        at_capacity = len(self._arrays) >= MAX_ARRAY_COUNT
        if at_capacity:
            raise MemoryError(f"array store limit reached: {MAX_ARRAY_COUNT} arrays")
        new_handle = f"a{self._counter}"
        self._arrays[new_handle] = arr
        self._counter += 1
        return new_handle

    def get(self, handle: str) -> Any:
        """Look up and return the object stored under *handle*.

        Raises
        ------
        KeyError
            With a descriptive message if *handle* is not in the store.
        """
        try:
            return self._arrays[handle]
        except KeyError:
            # Replace the bare-key KeyError with a readable message.
            raise KeyError(f"Array handle {handle!r} not found in store") from None

    def metadata(self, handle: str) -> dict:
        """Describe the array stored under *handle*.

        Returns
        -------
        dict
            ``{"id": handle, "shape": list[int], "dtype": str}``. A
            ``"symmetry"`` entry holding ``arr.symmetry.to_payload()`` is
            appended when the stored object has a non-``None`` ``symmetry``
            attribute.

        Raises
        ------
        KeyError
            If *handle* is not in the store (raised by :meth:`get`).
        """
        target = self.get(handle)
        info = dict(id=handle, shape=list(target.shape), dtype=str(target.dtype))
        sym = getattr(target, "symmetry", None)
        if sym is not None:
            info["symmetry"] = sym.to_payload()
        return info

    def free(self, handles: list[str]) -> None:
        """Drop every handle in *handles*; handles not in the store are skipped."""
        discard = self._arrays.pop
        for h in handles:
            discard(h, None)

    def clear(self) -> None:
        """Empty the store. The handle counter keeps its current value."""
        self._arrays.clear()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def count(self) -> int:
        """How many arrays the store currently holds."""
        return len(self._arrays)
@@ -0,0 +1,77 @@
1
+ """CommsTracker — accumulates per-request timing and byte counts for a session."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
class CommsTracker:
    """Running totals of traffic and timing over every request in a session.

    Each :meth:`record_request` call adds one request's byte counts and
    nanosecond timings to the totals, and :meth:`summary` reports them.
    """

    def __init__(self) -> None:
        # request counters
        self._request_count: int = 0
        self._fetch_count: int = 0
        # byte totals
        self._total_bytes_sent: int = 0
        self._total_bytes_received: int = 0
        # timing totals, in nanoseconds
        self._total_comms_overhead_ns: int = 0
        self._total_compute_time_ns: int = 0

    def record_request(
        self,
        *,
        bytes_received: int,
        bytes_sent: int,
        comms_overhead_ns: int,
        compute_time_ns: int,
        is_fetch: bool,
    ) -> None:
        """Fold one request's statistics into the running totals.

        Parameters
        ----------
        bytes_received:
            Bytes received for this request.
        bytes_sent:
            Bytes sent for this request.
        comms_overhead_ns:
            Communication overhead of this request, in nanoseconds.
        compute_time_ns:
            Compute time of this request, in nanoseconds.
        is_fetch:
            Truthy when the request retrieved array data. Such requests also
            increment the fetch counter.
        """
        self._request_count += 1
        self._fetch_count += 1 if is_fetch else 0
        self._total_bytes_sent += bytes_sent
        self._total_bytes_received += bytes_received
        self._total_comms_overhead_ns += comms_overhead_ns
        self._total_compute_time_ns += compute_time_ns

    def summary(self) -> dict:
        """Snapshot the accumulated statistics.

        Returns
        -------
        dict with keys, in this order:
            request_count: number of requests recorded
            fetch_count: number of those requests flagged as fetches
            total_bytes_sent: bytes sent, summed over all requests
            total_bytes_received: bytes received, summed over all requests
            total_comms_overhead_ns: summed communication overhead (ns)
            total_compute_time_ns: summed compute time (ns)
            overhead_ratio: comms / (comms + compute), or 0.0 when that sum is 0
        """
        comms = self._total_comms_overhead_ns
        compute = self._total_compute_time_ns
        denominator = comms + compute
        ratio = comms / denominator if denominator != 0 else 0.0

        return dict(
            request_count=self._request_count,
            fetch_count=self._fetch_count,
            total_bytes_sent=self._total_bytes_sent,
            total_bytes_received=self._total_bytes_received,
            total_comms_overhead_ns=comms,
            total_compute_time_ns=compute,
            overhead_ratio=ratio,
        )
@@ -0,0 +1,233 @@
1
+ """Protocol layer for flopscope server — message encoding/decoding.
2
+
3
+ Message format: msgpack over ZMQ.
4
+
5
+ All messages are msgpack-encoded dicts. By default, msgpack returns bytes keys
6
+ when raw=False is not used, so decode_request handles normalisation carefully:
7
+ - top-level string fields (op, dtype, request_id) are decoded to str
8
+ - binary payload fields (data) are kept as bytes
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+ import msgpack
16
+
17
+ from flopscope._registry import REGISTRY
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Whitelist
21
+ # ---------------------------------------------------------------------------
22
+
23
#: Protocol-level op names accepted in addition to the numpy REGISTRY ops.
_PROTOCOL_OPS: frozenset[str] = frozenset(
    (
        "hello",
        "budget_open",
        "budget_close",
        "budget_status",
        "fetch",
        "fetch_slice",
        "free",
        "create_from_data",
        "__getitem__",
        "astype",
    )
)
38
+
39
#: Every op name a client may request: each key of REGISTRY plus the names
#: in _PROTOCOL_OPS.
WHITELIST: frozenset[str] = _PROTOCOL_OPS.union(REGISTRY.keys())
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # String fields that should always be decoded from bytes -> str
44
+ # ---------------------------------------------------------------------------
45
+
46
#: Request fields whose bytes values decode_request converts to str.
_STRING_FIELDS: frozenset[str] = frozenset(("op", "dtype", "request_id"))
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Exceptions
51
+ # ---------------------------------------------------------------------------
52
+
53
+
54
class InvalidRequestError(Exception):
    """A client request was malformed or named an op that is not permitted.

    Raised by :func:`decode_request` for payloads it cannot decode and by
    :func:`validate_request` for ops outside the whitelist.
    """
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # decode_request
60
+ # ---------------------------------------------------------------------------
61
+
62
+
63
def decode_request(raw: bytes) -> dict:
    """Decode a msgpack-encoded request from the client.

    Parameters
    ----------
    raw:
        Raw msgpack bytes received from the client.

    Returns
    -------
    dict
        Decoded request with top-level keys as strings. Values of the fields
        listed in :data:`_STRING_FIELDS` (``op``, ``dtype``, ``request_id``)
        are decoded to :class:`str` when they arrive as bytes. Every other
        value, including binary payloads such as ``data``, is returned exactly
        as unpacked (with ``raw=True``, msgpack strings remain bytes).

    Raises
    ------
    InvalidRequestError
        If *raw* is empty, cannot be unpacked, is not a map, contains a key
        or string field that is not valid UTF-8, or is missing the ``op``
        field.
    """
    if not raw:
        raise InvalidRequestError("malformed request: empty bytes")

    try:
        # raw=True leaves every msgpack string as bytes, so only the keys and
        # the _STRING_FIELDS values are decoded below. Binary payloads
        # (e.g. ``data``) are never forced through UTF-8.
        msg_raw = msgpack.unpackb(raw, raw=True)
    except Exception as exc:
        raise InvalidRequestError(f"malformed request: {exc}") from exc

    if not isinstance(msg_raw, dict):
        raise InvalidRequestError("malformed request: top-level value must be a dict")

    msg: dict[str, Any] = {}
    for k, v in msg_raw.items():
        # Keys and string fields are client-controlled bytes and may not be
        # valid UTF-8. Report that as a malformed request rather than letting
        # a bare UnicodeDecodeError escape the documented contract.
        try:
            key: str = k.decode("utf-8") if isinstance(k, bytes) else str(k)
            if key in _STRING_FIELDS and isinstance(v, bytes):
                value: Any = v.decode("utf-8")
            else:
                value = v
        except UnicodeDecodeError as exc:
            raise InvalidRequestError(f"malformed request: invalid UTF-8: {exc}") from exc

        msg[key] = value

    if "op" not in msg:
        raise InvalidRequestError("malformed request: missing 'op' field")

    return msg
112
+
113
+
114
+ # ---------------------------------------------------------------------------
115
+ # validate_request
116
+ # ---------------------------------------------------------------------------
117
+
118
+
119
def validate_request(msg: dict) -> None:
    """Check that the op in *msg* is on the permitted whitelist.

    Parameters
    ----------
    msg:
        Decoded request dict (as returned by :func:`decode_request`).

    Raises
    ------
    InvalidRequestError
        If the op is missing, is not a string, or is not in :data:`WHITELIST`.
    """
    op = msg.get("op", "")
    # Check the type first. A decoded msgpack array or map would be
    # unhashable and make the frozenset membership test raise TypeError
    # instead of the documented InvalidRequestError.
    if not isinstance(op, str) or op not in WHITELIST:
        raise InvalidRequestError(f"unknown op: {op!r}")
135
+
136
+
137
+ # ---------------------------------------------------------------------------
138
+ # encode_response
139
+ # ---------------------------------------------------------------------------
140
+
141
+
142
def encode_response(result: Any, budget: int, comms_overhead_ns: int) -> bytes:
    """Serialise the outcome of a successful operation for the client.

    Parameters
    ----------
    result:
        Value produced by the operation. It must be something msgpack can
        serialise.
    budget:
        Budget remaining after the operation.
    comms_overhead_ns:
        Round-trip communication overhead, in nanoseconds.

    Returns
    -------
    bytes
        msgpack encoding (``use_bin_type=True``) of a dict with the keys
        ``status`` (always ``"ok"``), ``result``, ``budget`` and
        ``comms_overhead_ns``, in that order.
    """
    return msgpack.packb(
        dict(
            status="ok",
            result=result,
            budget=budget,
            comms_overhead_ns=comms_overhead_ns,
        ),
        use_bin_type=True,
    )
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # encode_error_response
170
+ # ---------------------------------------------------------------------------
171
+
172
+
173
def encode_error_response(error_type: str, message: str) -> bytes:
    """Serialise a failure report for the client.

    Parameters
    ----------
    error_type:
        Exception class name to report (e.g. ``"InvalidRequestError"``).
    message:
        Human-readable description of what went wrong.

    Returns
    -------
    bytes
        msgpack encoding (``use_bin_type=True``) of a dict with the keys
        ``status`` (always ``"error"``), ``error_type`` and ``message``, in
        that order.
    """
    return msgpack.packb(
        dict(status="error", error_type=error_type, message=message),
        use_bin_type=True,
    )
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # encode_fetch_response
198
+ # ---------------------------------------------------------------------------
199
+
200
+
201
def encode_fetch_response(
    data: bytes,
    shape: tuple[int, ...],
    dtype: str,
    comms_overhead_ns: int,
) -> bytes:
    """Serialise a fetch reply that carries an array's raw bytes.

    Parameters
    ----------
    data:
        The array's raw bytes (for example ``array.tobytes()``).
    shape:
        Array shape as any iterable of ints. It is sent as a list.
    dtype:
        NumPy dtype string (for example ``"float32"``).
    comms_overhead_ns:
        Round-trip communication overhead, in nanoseconds.

    Returns
    -------
    bytes
        msgpack encoding (``use_bin_type=True``) of a dict with the keys
        ``status`` (always ``"ok"``), ``data`` (kept as binary), ``shape``,
        ``dtype`` and ``comms_overhead_ns``, in that order.
    """
    body = dict(
        status="ok",
        data=data,
        shape=[*shape],
        dtype=dtype,
        comms_overhead_ns=comms_overhead_ns,
    )
    return msgpack.packb(body, use_bin_type=True)