flopscope-server 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flopscope_server/__init__.py +3 -0
- flopscope_server/__main__.py +42 -0
- flopscope_server/_array_store.py +97 -0
- flopscope_server/_comms_tracker.py +77 -0
- flopscope_server/_protocol.py +233 -0
- flopscope_server/_request_handler.py +440 -0
- flopscope_server/_server.py +375 -0
- flopscope_server/_session.py +171 -0
- flopscope_server-0.3.0.dist-info/METADATA +9 -0
- flopscope_server-0.3.0.dist-info/RECORD +12 -0
- flopscope_server-0.3.0.dist-info/WHEEL +4 -0
- flopscope_server-0.3.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
"""RequestHandler — dispatches decoded request dicts to flopscope functions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
import flopscope as flops
|
|
12
|
+
from flopscope._perm_group import SymmetryGroup, _Permutation
|
|
13
|
+
from flopscope_server._session import Session
|
|
14
|
+
|
|
15
|
+
_HANDLE_RE = re.compile(r"^a\d+$")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _make_serializable(obj):
|
|
19
|
+
"""Convert a nested structure to be msgpack-safe (no numpy types)."""
|
|
20
|
+
if isinstance(obj, np.ndarray):
|
|
21
|
+
return obj.tolist()
|
|
22
|
+
if isinstance(obj, np.generic):
|
|
23
|
+
return obj.item()
|
|
24
|
+
if isinstance(obj, (list, tuple)):
|
|
25
|
+
return [_make_serializable(item) for item in obj]
|
|
26
|
+
if isinstance(obj, dict):
|
|
27
|
+
return {k: _make_serializable(v) for k, v in obj.items()}
|
|
28
|
+
return obj
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# 100 MiB.
_DEFAULT_MAX_ARRAY_BYTES = 100 * 1024 * 1024

#: Maximum allowed array size in bytes.  Applied both to arrays uploaded by
#: clients and to arrays returned by flopscope calls.  Override with the
#: ``FLOPSCOPE_MAX_ARRAY_BYTES`` environment variable, which is read once at
#: import time.
MAX_ARRAY_BYTES = int(
    os.environ.get("FLOPSCOPE_MAX_ARRAY_BYTES", _DEFAULT_MAX_ARRAY_BYTES)
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RequestHandler:
    """Dispatch decoded request dicts to real flopscope functions.

    :meth:`handle` is the single public entry point.  It routes a request
    either to a built-in op or to a flopscope function looked up by name:

    - built-in ops: array transfer, freeing, budget queries, indexing;
    - flopscope functions: any other op name.

    Any ``Exception`` raised while handling is converted into an error
    response dict instead of being propagated.

    Parameters
    ----------
    session : Session
        The active session providing array storage, budget context, etc.
    """

    def __init__(self, session: Session) -> None:
        self._session = session

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def handle(self, request: dict) -> dict:
        """Dispatch *request* and return a response dict.

        The ``request["op"]`` field determines which handler is invoked.
        Successful responses carry ``"status": "ok"``.  Failures carry
        ``"status": "error"`` plus ``error_type`` and ``message``, where
        ``error_type`` is:

        - the error's own name for the known flopscope errors, ``ValueError``,
          ``TypeError`` and ``KeyError``;
        - ``FlopscopeServerError`` for anything else.
        """
        try:
            op = request["op"]

            if op == "fetch":
                return self._handle_fetch(request)
            if op == "fetch_slice":
                return self._handle_fetch_slice(request)
            if op == "free":
                return self._handle_free(request)
            if op == "budget_status":
                return self._handle_budget_status()
            if op == "create_from_data":
                return self._handle_create_from_data(request)
            if op == "__getitem__":
                return self._handle_getitem(request)

            # Any other op — flopscope function call
            return self._handle_flopscope_op(request)

        # The flopscope-specific clauses come before the generic
        # ValueError/TypeError clause so they win should they subclass those.
        except flops.BudgetExhaustedError as e:
            return self._error("BudgetExhaustedError", str(e))
        except flops.NoBudgetContextError as e:
            return self._error("NoBudgetContextError", str(e))
        except flops.SymmetryError as e:
            return self._error("SymmetryError", str(e))
        except flops.UnsupportedFunctionError as e:
            return self._error("UnsupportedFunctionError", str(e))
        except (ValueError, TypeError) as e:
            return self._error(type(e).__name__, str(e))
        except KeyError as e:
            return self._error("KeyError", str(e))
        except Exception as e:
            return self._error(
                "FlopscopeServerError",
                f"internal server error: {type(e).__name__}: {e}",
            )

    @staticmethod
    def _error(error_type: str, message: str) -> dict:
        """Build an error response dict."""
        return {"status": "error", "error_type": error_type, "message": message}

    # ------------------------------------------------------------------
    # Built-in ops
    # ------------------------------------------------------------------

    def _handle_fetch(self, request: dict) -> dict:
        """Return the raw bytes, shape and dtype of a stored array.

        The handle is read from ``request["id"]``, falling back to
        ``request["kwargs"]["handle_id"]``.

        Raises
        ------
        KeyError
            If neither field is present.
        """
        # Support both direct "id" field and kwargs-based "handle_id"
        handle = request.get("id")
        if handle is None:
            kwargs = request.get("kwargs") or {}
            handle = kwargs.get("handle_id")
            if handle is None:
                raise KeyError("fetch requires 'id' or kwargs.handle_id")
        arr = self._session.get_array(handle)
        return {
            "status": "ok",
            "data": arr.tobytes(),
            "shape": list(arr.shape),
            "dtype": str(arr.dtype),
        }

    def _handle_fetch_slice(self, request: dict) -> dict:
        """Return the raw bytes, shape and dtype of ``arr[slices]``.

        Each entry of ``request["slices"]`` is unpacked into ``slice(*entry)``,
        one slice per indexed axis.
        """
        arr = self._session.get_array(request["id"])
        slices = tuple(slice(*s) for s in request["slices"])
        sliced = arr[slices]
        # np.shape() is () for 0-d results, so one response form covers both
        # scalar and array slices.
        return {
            "status": "ok",
            "data": sliced.tobytes(),
            "shape": list(np.shape(sliced)),
            "dtype": str(sliced.dtype),
        }

    def _handle_free(self, request: dict) -> dict:
        """Release the arrays named by ``request["ids"]``.

        Falls back to ``request["kwargs"]["handles"]``; if neither is present,
        an empty list is passed to the session.
        """
        # Support both direct "ids" field and kwargs-based "handles"
        ids = request.get("ids")
        if ids is None:
            kwargs = request.get("kwargs") or {}
            ids = kwargs.get("handles", [])
        self._session.free_arrays(ids)
        return {"status": "ok"}

    def _handle_budget_status(self) -> dict:
        """Return the session's current budget status."""
        return {"status": "ok", "result": self._session.budget_status()}

    def _handle_create_from_data(self, request: dict) -> dict:
        """Store an array built from raw bytes sent by the client.

        Accepts either top-level ``data``/``shape``/``dtype`` fields or
        ``args = [data, shape, dtype]``.  Payloads larger than
        ``MAX_ARRAY_BYTES`` produce a ``ValueError`` error response.

        Raises
        ------
        ValueError
            If fewer than three positional args are supplied.
        """
        # Support both direct fields and args-based [data, shape, dtype]
        if "data" in request:
            data = request["data"]
            shape = request["shape"]
            dtype = request["dtype"]
        else:
            args = request.get("args", [])
            if len(args) >= 3:
                data, shape, dtype = args[0], args[1], args[2]
            else:
                raise ValueError("create_from_data requires data, shape, dtype")
        # Ensure dtype is a string (may be bytes from msgpack)
        if isinstance(dtype, bytes):
            dtype = dtype.decode("utf-8")
        if len(data) > MAX_ARRAY_BYTES:
            return self._error(
                "ValueError",
                f"array too large: {len(data)} bytes exceeds {MAX_ARRAY_BYTES} byte limit",
            )
        # frombuffer gives a read-only view of the message bytes; copy so the
        # stored array owns (and can mutate) its own memory.
        arr = np.frombuffer(data, dtype=dtype).reshape(shape).copy()
        handle = self._session.store_array(arr)
        meta = self._session.array_metadata(handle)
        return {
            "status": "ok",
            "result": meta,
            "budget": self._session.budget_status(),
        }

    # ------------------------------------------------------------------
    # __getitem__ dispatch
    # ------------------------------------------------------------------

    def _handle_getitem(self, request: dict) -> dict:
        """Handle array indexing: arr[key] on the server side.

        ``request["args"]`` must be ``[handle, key]``; *key* is decoded with
        :meth:`_decode_index_key`.
        """
        args = request.get("args") or []
        if len(args) < 2:
            raise ValueError("__getitem__ requires [handle, key]")
        arr = self._resolve_arg(args[0])
        key = self._decode_index_key(args[1])
        result = arr[key]
        return self._pack_result(result)

    # ------------------------------------------------------------------
    # Flopscope function dispatch
    # ------------------------------------------------------------------

    def _handle_flopscope_op(self, request: dict) -> dict:
        """Call the flopscope function named by ``request["op"]``.

        Handle references in ``args``/``kwargs`` are resolved to stored arrays
        first.  ``astype`` is special-cased because it is an ndarray method,
        not a module-level function.

        Raises
        ------
        ValueError
            If ``astype`` is called without an array argument.
        """
        op = request["op"]
        raw_args = request.get("args") or []
        kwargs = request.get("kwargs") or {}

        # Special-case: astype is an ndarray method, not a module function
        if op == "astype":
            if not raw_args:
                raise ValueError("astype requires [array, dtype]")
            arr = self._resolve_arg(raw_args[0])
            dtype = raw_args[1] if len(raw_args) > 1 else kwargs.get("dtype")
            if isinstance(dtype, bytes):
                dtype = dtype.decode("utf-8")
            result = arr.astype(dtype)
            return self._pack_result(result)

        func = _get_flopscope_func(op)
        resolved_args = [self._resolve_arg(a) for a in raw_args]
        resolved_kwargs = {k: self._resolve_arg(v) for k, v in kwargs.items()}

        result = func(*resolved_args, **resolved_kwargs)

        return self._pack_result(result)

    # ------------------------------------------------------------------
    # Argument resolution
    # ------------------------------------------------------------------

    def _resolve_arg(self, arg: Any) -> Any:
        """Resolve a single argument: handle IDs become arrays, rest pass through.

        Recognised forms:

        - bare handle strings (``"a0"``);
        - ``{"__handle__": "a0"}`` dicts;
        - ``{"__symmetry_group__": payload}`` dicts, which become a
          ``SymmetryGroup``.

        Dict keys/values may arrive as ``bytes``.  Lists and tuples are
        resolved element-wise, and their type is preserved.
        """
        # fullmatch: the whole string must be a handle ID.
        if isinstance(arg, str) and _HANDLE_RE.fullmatch(arg):
            return self._session.get_array(arg)
        # Support {"__handle__": "a0"} dict format from the client
        if isinstance(arg, dict):
            handle = arg.get("__handle__")
            if handle is None:
                # Try bytes key variant (msgpack may leave keys as bytes)
                handle = arg.get(b"__handle__")
            if handle is not None:
                if isinstance(handle, bytes):
                    handle = handle.decode("utf-8")
                return self._session.get_array(handle)
            # SymmetryGroup wire format
            pg_data = arg.get("__symmetry_group__") or arg.get(b"__symmetry_group__")
            if pg_data is not None:
                if isinstance(pg_data, dict):
                    pg_data = {
                        (k.decode("utf-8") if isinstance(k, bytes) else k): v
                        for k, v in pg_data.items()
                    }
                return SymmetryGroup.from_payload(pg_data)
        # Recurse into lists/tuples so that e.g. concatenate([a, b]) works
        if isinstance(arg, (list, tuple)):
            resolved = [self._resolve_arg(item) for item in arg]
            return type(arg)(resolved) if isinstance(arg, tuple) else resolved
        return arg

    # ------------------------------------------------------------------
    # Index-key decoding
    # ------------------------------------------------------------------

    def _decode_index_key(self, raw_key):
        """Decode a serialised index key from the client.

        Supports:

        - int / float -> int.  ``bool`` passes through unchanged, because
          numpy treats a boolean scalar index differently from ``1``/``0``.
        - ``{"__handle__": "a0"}`` -> the stored array (fancy / boolean
          indexing with RemoteArrays).
        - ``{"__slice__": [start, stop, step]}`` -> slice.
        - list of the above -> tuple (multi-dimensional indexing).  A
          single-element list without slices stays a list.

        Dict keys may arrive as ``bytes``.  Anything else is returned
        unchanged.
        """
        if isinstance(raw_key, dict):
            # Handle dict: {"__handle__": "a0"} for fancy indexing
            handle = raw_key.get("__handle__") or raw_key.get(b"__handle__")
            if handle is not None:
                if isinstance(handle, bytes):
                    handle = handle.decode()
                return self._session.get_array(handle)
            for slice_key in ("__slice__", b"__slice__"):
                if slice_key in raw_key:
                    parts = raw_key[slice_key]
                    return slice(*[None if p is None else int(p) for p in parts])
        if isinstance(raw_key, list):
            decoded = [self._decode_index_key(item) for item in raw_key]
            if any(isinstance(d, slice) for d in decoded) or len(decoded) > 1:
                return tuple(decoded)
            return decoded
        if isinstance(raw_key, (int, float)) and not isinstance(raw_key, bool):
            return int(raw_key)
        return raw_key

    # ------------------------------------------------------------------
    # Result packing
    # ------------------------------------------------------------------

    def _pack_result(self, result: Any) -> dict:
        """Pack a flopscope function result into a response dict.

        - ``np.ndarray``: stored in the session; ``result`` is its metadata.
        - ``tuple``/``list``: ``{"multi": [...]}`` with one entry per item.
          Arrays are stored.  Nested sequences are serialised with dtype
          ``"object"``.  Everything else is encoded exactly like a scalar
          result.
        - anything else: encoded by :meth:`_scalar_payload`.

        Arrays larger than ``MAX_ARRAY_BYTES`` produce a ``ValueError`` error
        response, whether returned alone or inside a sequence.
        """
        budget = self._session.budget_status()

        if isinstance(result, np.ndarray):
            if result.nbytes > MAX_ARRAY_BYTES:
                return self._result_too_large(result.nbytes)
            handle = self._session.store_array(result)
            meta = self._session.array_metadata(handle)
            return {"status": "ok", "result": meta, "budget": budget}

        if isinstance(result, (tuple, list)):
            # Check every array before storing any, so a rejected result does
            # not leave some of its arrays stored in the session.
            for r in result:
                if isinstance(r, np.ndarray) and r.nbytes > MAX_ARRAY_BYTES:
                    return self._result_too_large(r.nbytes)
            items = []
            for r in result:
                if isinstance(r, np.ndarray):
                    handle = self._session.store_array(r)
                    items.append(self._session.array_metadata(handle))
                elif isinstance(r, (list, tuple)):
                    # Nested list/tuple (e.g., from einsum_path) — convert to JSON-safe
                    items.append({"value": _make_serializable(r), "dtype": "object"})
                else:
                    items.append(self._scalar_payload(r))
            return {"status": "ok", "result": {"multi": items}, "budget": budget}

        # Scalar or other value
        return {"status": "ok", "result": self._scalar_payload(result), "budget": budget}

    def _result_too_large(self, nbytes: int) -> dict:
        """Error response for a result array exceeding ``MAX_ARRAY_BYTES``."""
        return self._error(
            "ValueError",
            f"result array too large: {nbytes} bytes exceeds {MAX_ARRAY_BYTES} byte limit",
        )

    @staticmethod
    def _scalar_payload(value: Any) -> dict:
        """Encode a non-array value as ``{"value": ..., "dtype": ...}``.

        Check order matters:

        - numpy scalars come first, because ``np.float64`` subclasses
          ``float``;
        - ``bool`` comes before ``int``, because ``bool`` subclasses ``int``.

        Values of no recognised type are passed through
        :func:`_make_serializable` (no ``dtype`` key).  If that raises, they
        fall back to their ``str`` form.
        """
        if isinstance(value, np.generic):
            return {"value": value.item(), "dtype": str(value.dtype)}
        if isinstance(value, bool):
            return {"value": value, "dtype": "bool"}
        if isinstance(value, int):
            return {"value": value, "dtype": "int64"}
        if isinstance(value, float):
            return {"value": value, "dtype": "float64"}
        if isinstance(value, str):
            return {"value": value, "dtype": "str"}
        if isinstance(value, np.dtype):
            return {"value": str(value), "dtype": "str"}
        # Fallback: try to make it serializable
        try:
            return {"value": _make_serializable(value)}
        except Exception:
            return {"value": str(value), "dtype": "str"}
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
# ---------------------------------------------------------------------------
|
|
386
|
+
# Helper
|
|
387
|
+
# ---------------------------------------------------------------------------
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _get_flopscope_func(op_name: str):
    """Look up a flopscope op by dotted name (e.g. 'linalg.svd', 'stats.norm.pdf').

    Post-rebrand layout (JAX-style):

    - Numpy-shaped ops (einsum, linalg.*, fft.*, random.*) live under
      :mod:`flopscope.numpy`.
    - Stats distributions (stats.norm.*, stats.uniform.*, ...) live under
      :mod:`flopscope.stats` (top-level — closer in spirit to scipy.stats
      than numpy).

    We try ``flopscope.numpy`` first.  If the full dotted name does not
    resolve there, we fall back to top-level ``flopscope``, so submodules
    like ``stats`` continue to resolve.

    *op_name* comes straight from the client, so resolution is restricted to
    flopscope's public surface:

    - components that are empty or start with an underscore (e.g.
      ``__class__``) are refused;
    - so is stepping onto a module outside the ``flopscope`` package, which
      keeps modules flopscope merely imports (and their members) out of reach.

    Raises
    ------
    AttributeError
        If *op_name* does not name a permitted flopscope attribute.
    """
    import types

    import flopscope.numpy as fnp

    def _is_foreign_module(obj) -> bool:
        # True for module objects that are not flopscope or one of its submodules.
        if not isinstance(obj, types.ModuleType):
            return False
        name = getattr(obj, "__name__", "") or ""
        return not (name == "flopscope" or name.startswith("flopscope."))

    parts = op_name.split(".")
    if any(not part or part.startswith("_") for part in parts):
        raise AttributeError(f"flopscope does not provide {op_name!r}")

    for base in (fnp, flops):
        obj = base
        try:
            for part in parts:
                obj = getattr(obj, part)
                if _is_foreign_module(obj):
                    raise AttributeError(part)
        except AttributeError:
            continue
        return obj
    raise AttributeError(f"flopscope does not provide {op_name!r}")
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def _decode_index_key(raw_key):
|
|
419
|
+
"""Decode a serialised index key from the client.
|
|
420
|
+
|
|
421
|
+
Supports:
|
|
422
|
+
- int / float -> int
|
|
423
|
+
- ``{"__slice__": [start, stop, step]}`` -> slice
|
|
424
|
+
- list of the above -> tuple (for multi-dimensional indexing)
|
|
425
|
+
"""
|
|
426
|
+
if isinstance(raw_key, dict):
|
|
427
|
+
if "__slice__" in raw_key:
|
|
428
|
+
parts = raw_key["__slice__"]
|
|
429
|
+
return slice(*[None if p is None else int(p) for p in parts])
|
|
430
|
+
if isinstance(raw_key, list):
|
|
431
|
+
decoded = [_decode_index_key(item) for item in raw_key]
|
|
432
|
+
# A list of slices/ints -> tuple for multi-dim indexing
|
|
433
|
+
if any(isinstance(d, slice) for d in decoded) or len(decoded) > 1:
|
|
434
|
+
return tuple(decoded)
|
|
435
|
+
# Single-element list: could be the key itself being a list
|
|
436
|
+
# (e.g., fancy indexing) -- keep as list
|
|
437
|
+
return decoded
|
|
438
|
+
if isinstance(raw_key, (int, float)):
|
|
439
|
+
return int(raw_key)
|
|
440
|
+
return raw_key
|