flopscope-server 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,440 @@
1
+ """RequestHandler — dispatches decoded request dicts to flopscope functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import re
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+
11
+ import flopscope as flops
12
+ from flopscope._perm_group import SymmetryGroup, _Permutation
13
+ from flopscope_server._session import Session
14
+
15
#: Matches bare string handle IDs: "a" followed by one or more digits
#: (e.g. "a0", "a17"). ``RequestHandler._resolve_arg`` treats any string
#: argument matching this pattern as a reference to a stored array rather
#: than as a literal string.
#: NOTE(review): under ``re.match`` a ``$`` anchor also matches just before a
#: trailing newline (so "a0\n" matches), and ``\d`` accepts non-ASCII digits.
_HANDLE_RE = re.compile(r"^a\d+$")
16
+
17
+
18
+ def _make_serializable(obj):
19
+ """Convert a nested structure to be msgpack-safe (no numpy types)."""
20
+ if isinstance(obj, np.ndarray):
21
+ return obj.tolist()
22
+ if isinstance(obj, np.generic):
23
+ return obj.item()
24
+ if isinstance(obj, (list, tuple)):
25
+ return [_make_serializable(item) for item in obj]
26
+ if isinstance(obj, dict):
27
+ return {k: _make_serializable(v) for k, v in obj.items()}
28
+ return obj
29
+
30
+
31
#: Upper bound, in bytes, on any single array the server accepts from a
#: client or hands back as a result. Read once, at import time, from the
#: FLOPSCOPE_MAX_ARRAY_BYTES environment variable; defaults to 100 MiB.
MAX_ARRAY_BYTES = int(
    os.environ.get("FLOPSCOPE_MAX_ARRAY_BYTES", 100 << 20)  # 100 MiB
)
33
+
34
+
35
class RequestHandler:
    """Dispatch decoded request dicts to real flopscope functions.

    Every response is a dict whose ``"status"`` is ``"ok"`` or ``"error"``.
    Error responses carry ``"error_type"`` and ``"message"``. Responses built
    by :meth:`_pack_result` (and by ``create_from_data``) also carry the
    session's current ``"budget"`` status.

    Parameters
    ----------
    session : Session
        The active session providing array storage, budget context, etc.
    """

    def __init__(self, session: Session) -> None:
        self._session = session

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def handle(self, request: dict) -> dict:
        """Dispatch *request* and return a response dict.

        ``request["op"]`` selects the handler. The built-in ops ``fetch``,
        ``fetch_slice``, ``free``, ``budget_status``, ``create_from_data``
        and ``__getitem__`` are served directly. Any other op name is
        resolved to a flopscope function and called.

        No exception escapes this method:

        - flopscope errors are reported under their own ``error_type``;
        - ``ValueError``/``TypeError``/``IndexError`` (and subclasses) are
          reported under their own class name;
        - ``KeyError`` is reported as ``KeyError``;
        - anything else is reported as ``FlopscopeServerError``.
        """
        try:
            op = request["op"]

            if op == "fetch":
                return self._handle_fetch(request)
            if op == "fetch_slice":
                return self._handle_fetch_slice(request)
            if op == "free":
                return self._handle_free(request)
            if op == "budget_status":
                return self._handle_budget_status()
            if op == "create_from_data":
                return self._handle_create_from_data(request)
            if op == "__getitem__":
                return self._handle_getitem(request)

            # Any other op — flopscope function call
            return self._handle_flopscope_op(request)

        except flops.BudgetExhaustedError as e:
            return {
                "status": "error",
                "error_type": "BudgetExhaustedError",
                "message": str(e),
            }
        except flops.NoBudgetContextError as e:
            return {
                "status": "error",
                "error_type": "NoBudgetContextError",
                "message": str(e),
            }
        except flops.SymmetryError as e:
            return {"status": "error", "error_type": "SymmetryError", "message": str(e)}
        except flops.UnsupportedFunctionError as e:
            return {
                "status": "error",
                "error_type": "UnsupportedFunctionError",
                "message": str(e),
            }
        except (ValueError, TypeError, IndexError) as e:
            # Caller mistakes (bad dtype/shape, wrong argument types,
            # out-of-range indices) keep their own type name instead of
            # being reported as an internal server error.
            return {
                "status": "error",
                "error_type": type(e).__name__,
                "message": str(e),
            }
        except KeyError as e:
            return {"status": "error", "error_type": "KeyError", "message": str(e)}
        except Exception as e:
            return {
                "status": "error",
                "error_type": "FlopscopeServerError",
                "message": f"internal server error: {type(e).__name__}: {e}",
            }

    # ------------------------------------------------------------------
    # Built-in ops
    # ------------------------------------------------------------------

    def _handle_fetch(self, request: dict) -> dict:
        """Return the raw bytes, shape and dtype of one stored array.

        The handle is read from ``request["id"]`` or, failing that, from
        ``request["kwargs"]["handle_id"]``.

        Raises
        ------
        KeyError
            If neither location supplies a handle.
        """
        handle = request.get("id")
        if handle is None:
            kwargs = request.get("kwargs") or {}
            handle = kwargs.get("handle_id")
        if handle is None:
            raise KeyError("fetch requires 'id' or kwargs.handle_id")
        arr = self._session.get_array(handle)
        return {
            "status": "ok",
            "data": arr.tobytes(),
            "shape": list(arr.shape),
            "dtype": str(arr.dtype),
        }

    def _handle_fetch_slice(self, request: dict) -> dict:
        """Return the raw bytes, shape and dtype of ``arr[slices]``.

        ``request["slices"]`` holds one argument list per sliced axis; each
        list is unpacked into :class:`slice` (e.g. ``[start, stop, step]``).
        """
        arr = self._session.get_array(request["id"])
        slices = tuple(slice(*s) for s in request["slices"])
        sliced = arr[slices]
        # A 0-d result has shape (), so list(shape) is already [] and no
        # separate scalar branch is needed.
        return {
            "status": "ok",
            "data": sliced.tobytes(),
            "shape": list(sliced.shape),
            "dtype": str(sliced.dtype),
        }

    def _handle_free(self, request: dict) -> dict:
        """Release the stored arrays named by ``ids`` (or ``kwargs["handles"]``)."""
        ids = request.get("ids")
        if ids is None:
            kwargs = request.get("kwargs") or {}
            ids = kwargs.get("handles", [])
        self._session.free_arrays(ids)
        return {"status": "ok"}

    def _handle_budget_status(self) -> dict:
        """Return the session's current budget status."""
        return {"status": "ok", "result": self._session.budget_status()}

    def _handle_create_from_data(self, request: dict) -> dict:
        """Build an array from raw bytes, store it, and return its metadata.

        Accepts either top-level ``data``/``shape``/``dtype`` fields or
        ``args = [data, shape, dtype]``. Payloads longer than
        ``MAX_ARRAY_BYTES`` are rejected with an error response.

        Raises
        ------
        ValueError
            If the args form has fewer than three entries, or if numpy
            rejects the dtype or the shape.
        """
        if "data" in request:
            data = request["data"]
            shape = request["shape"]
            dtype = request["dtype"]
        else:
            # `or []` also covers an explicit null "args" field.
            args = request.get("args") or []
            if len(args) < 3:
                raise ValueError("create_from_data requires data, shape, dtype")
            data, shape, dtype = args[0], args[1], args[2]
        # Ensure dtype is a string (may be bytes from msgpack)
        if isinstance(dtype, bytes):
            dtype = dtype.decode("utf-8")
        if len(data) > MAX_ARRAY_BYTES:
            return {
                "status": "error",
                "error_type": "ValueError",
                "message": f"array too large: {len(data)} bytes exceeds {MAX_ARRAY_BYTES} byte limit",
            }
        # frombuffer shares memory with `data` (read-only for bytes input);
        # copy so the stored array owns its own memory.
        arr = np.frombuffer(data, dtype=dtype).reshape(shape).copy()
        handle = self._session.store_array(arr)
        meta = self._session.array_metadata(handle)
        return {
            "status": "ok",
            "result": meta,
            "budget": self._session.budget_status(),
        }

    # ------------------------------------------------------------------
    # __getitem__ dispatch
    # ------------------------------------------------------------------

    def _handle_getitem(self, request: dict) -> dict:
        """Handle array indexing: arr[key] on the server side.

        ``request["args"]`` must be ``[handle, key]``. The key is decoded by
        :meth:`_decode_index_key`.
        """
        args = request.get("args") or []
        if len(args) < 2:
            raise ValueError("__getitem__ requires [handle, key]")
        arr = self._resolve_arg(args[0])
        key = self._decode_index_key(args[1])
        result = arr[key]
        return self._pack_result(result)

    # ------------------------------------------------------------------
    # Flopscope function dispatch
    # ------------------------------------------------------------------

    def _handle_flopscope_op(self, request: dict) -> dict:
        """Call the flopscope function named by ``request["op"]``.

        Positional ``args`` and keyword ``kwargs`` are passed through
        :meth:`_resolve_arg` (handles become stored arrays) before the call.
        The return value is packed with :meth:`_pack_result`.
        """
        op = request["op"]
        raw_args = request.get("args") or []
        kwargs = request.get("kwargs") or {}

        # Special-case: astype is an ndarray method, not a module function
        if op == "astype":
            if not raw_args:
                raise ValueError("astype requires [handle, dtype]")
            arr = self._resolve_arg(raw_args[0])
            dtype = raw_args[1] if len(raw_args) > 1 else kwargs.get("dtype")
            if isinstance(dtype, bytes):
                dtype = dtype.decode("utf-8")
            return self._pack_result(arr.astype(dtype))

        func = _get_flopscope_func(op)
        resolved_args = [self._resolve_arg(a) for a in raw_args]
        resolved_kwargs = {k: self._resolve_arg(v) for k, v in kwargs.items()}
        return self._pack_result(func(*resolved_args, **resolved_kwargs))

    # ------------------------------------------------------------------
    # Argument resolution
    # ------------------------------------------------------------------

    def _resolve_arg(self, arg: Any) -> Any:
        """Resolve a single argument: handle IDs become arrays, rest pass through.

        Recognised forms:

        - a string matching ``_HANDLE_RE`` -> the stored array;
        - ``{"__handle__": id}`` (str or bytes key/value) -> the stored array;
        - ``{"__symmetry_group__": payload}`` -> ``SymmetryGroup.from_payload``;
        - lists/tuples -> resolved element-wise, keeping list vs tuple.
        """
        if isinstance(arg, str) and _HANDLE_RE.match(arg):
            return self._session.get_array(arg)
        # Support {"__handle__": "a0"} dict format from the client
        if isinstance(arg, dict):
            handle = arg.get("__handle__")
            if handle is None:
                # Try bytes key variant (msgpack may leave keys as bytes)
                handle = arg.get(b"__handle__")
            if handle is not None:
                if isinstance(handle, bytes):
                    handle = handle.decode("utf-8")
                return self._session.get_array(handle)
            # SymmetryGroup wire format
            pg_data = arg.get("__symmetry_group__") or arg.get(b"__symmetry_group__")
            if pg_data is not None:
                if isinstance(pg_data, dict):
                    pg_data = {
                        (k.decode("utf-8") if isinstance(k, bytes) else k): v
                        for k, v in pg_data.items()
                    }
                return SymmetryGroup.from_payload(pg_data)
        # Recurse into lists/tuples so that e.g. concatenate([a, b]) works
        if isinstance(arg, (list, tuple)):
            resolved = [self._resolve_arg(item) for item in arg]
            return type(arg)(resolved) if isinstance(arg, tuple) else resolved
        return arg

    # ------------------------------------------------------------------
    # Index-key decoding
    # ------------------------------------------------------------------

    def _decode_index_key(self, raw_key):
        """Decode a serialised index key from the client (instance method).

        Supports, in addition to ints and slice dicts, handle dicts
        (``{"__handle__": id}``) for fancy indexing with RemoteArrays.

        - int / float -> int
        - ``{"__slice__": [start, stop, step]}`` (str or bytes key) -> slice
        - list -> tuple when it contains a slice or has more than one
          element, otherwise a one-element list
        """
        if isinstance(raw_key, dict):
            # Handle dict: {"__handle__": "a0"} for fancy indexing
            handle = raw_key.get("__handle__") or raw_key.get(b"__handle__")
            if handle is not None:
                if isinstance(handle, bytes):
                    handle = handle.decode()
                return self._session.get_array(handle)
            if "__slice__" in raw_key:
                parts = raw_key["__slice__"]
                return slice(*[None if p is None else int(p) for p in parts])
            if b"__slice__" in raw_key:
                parts = raw_key[b"__slice__"]
                return slice(*[None if p is None else int(p) for p in parts])
        if isinstance(raw_key, list):
            decoded = [self._decode_index_key(item) for item in raw_key]
            if any(isinstance(d, slice) for d in decoded) or len(decoded) > 1:
                return tuple(decoded)
            return decoded
        if isinstance(raw_key, (int, float)):
            return int(raw_key)
        return raw_key

    # ------------------------------------------------------------------
    # Result packing
    # ------------------------------------------------------------------

    def _pack_result(self, result: Any) -> dict:
        """Pack a flopscope function result into a response dict.

        - Arrays are stored in the session; their metadata is returned.
        - Tuples/lists become ``{"multi": [...]}``, one entry per element
          (see :meth:`_pack_multi_item`).
        - Scalars, strings and dtypes become ``{"value": ..., "dtype": ...}``.
        - Anything else goes through ``_make_serializable``.

        Any array (alone or inside a tuple/list) larger than
        ``MAX_ARRAY_BYTES`` yields an error response instead.
        """
        budget = self._session.budget_status()

        if isinstance(result, np.ndarray):
            if result.nbytes > MAX_ARRAY_BYTES:
                return {
                    "status": "error",
                    "error_type": "ValueError",
                    "message": f"result array too large: {result.nbytes} bytes exceeds {MAX_ARRAY_BYTES} byte limit",
                }
            handle = self._session.store_array(result)
            meta = self._session.array_metadata(handle)
            return {"status": "ok", "result": meta, "budget": budget}

        if isinstance(result, (tuple, list)):
            # Apply the same size limit as single-array results. Check every
            # element before storing any, so a rejection leaves no orphaned
            # handles behind.
            for r in result:
                if isinstance(r, np.ndarray) and r.nbytes > MAX_ARRAY_BYTES:
                    return {
                        "status": "error",
                        "error_type": "ValueError",
                        "message": f"result array too large: {r.nbytes} bytes exceeds {MAX_ARRAY_BYTES} byte limit",
                    }
            items = [self._pack_multi_item(r) for r in result]
            return {"status": "ok", "result": {"multi": items}, "budget": budget}

        # Scalar or other value
        if isinstance(result, np.generic):
            dtype_str = str(result.dtype)
            return {
                "status": "ok",
                "result": {"value": result.item(), "dtype": dtype_str},
                "budget": budget,
            }
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(result, bool):
            return {
                "status": "ok",
                "result": {"value": result, "dtype": "bool"},
                "budget": budget,
            }
        if isinstance(result, int):
            return {
                "status": "ok",
                "result": {"value": result, "dtype": "int64"},
                "budget": budget,
            }
        if isinstance(result, float):
            return {
                "status": "ok",
                "result": {"value": result, "dtype": "float64"},
                "budget": budget,
            }
        if isinstance(result, str):
            return {
                "status": "ok",
                "result": {"value": result, "dtype": "str"},
                "budget": budget,
            }
        if isinstance(result, np.dtype):
            return {
                "status": "ok",
                "result": {"value": str(result), "dtype": "str"},
                "budget": budget,
            }
        # Fallback: try to make it serializable
        try:
            serializable = _make_serializable(result)
            return {"status": "ok", "result": {"value": serializable}, "budget": budget}
        except Exception:
            return {
                "status": "ok",
                "result": {"value": str(result), "dtype": "str"},
                "budget": budget,
            }

    def _pack_multi_item(self, item: Any) -> dict:
        """Encode one element of a tuple/list result for ``{"multi": [...]}``.

        Arrays are stored and replaced by their metadata. Scalars carry the
        same ``dtype`` labels as single scalar results, so a Python bool is
        labelled ``"bool"`` rather than ``"int64"``. Nested lists/tuples
        (e.g. from einsum_path) are converted to msgpack-safe values.
        """
        if isinstance(item, np.ndarray):
            handle = self._session.store_array(item)
            return self._session.array_metadata(handle)
        if isinstance(item, np.generic):
            return {"value": item.item(), "dtype": str(item.dtype)}
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(item, bool):
            return {"value": item, "dtype": "bool"}
        if isinstance(item, int):
            return {"value": item, "dtype": "int64"}
        if isinstance(item, float):
            return {"value": item, "dtype": "float64"}
        if isinstance(item, str):
            return {"value": item, "dtype": "str"}
        if isinstance(item, np.dtype):
            # dtype objects are not msgpack-encodable; send the name.
            return {"value": str(item), "dtype": "str"}
        if isinstance(item, (list, tuple)):
            return {"value": _make_serializable(item), "dtype": "object"}
        return {"value": item}
383
+
384
+
385
+ # ---------------------------------------------------------------------------
386
+ # Helper
387
+ # ---------------------------------------------------------------------------
388
+
389
+
390
def _get_flopscope_func(op_name: str):
    """Look up a flopscope op by dotted name (e.g. 'linalg.svd', 'stats.norm.pdf').

    Post-rebrand layout (JAX-style):

    - Numpy-shaped ops (einsum, linalg.*, fft.*, random.*) live under
      :mod:`flopscope.numpy`.
    - Stats distributions (stats.norm.*, stats.uniform.*, ...) live under
      :mod:`flopscope.stats` (top-level — closer in spirit to scipy.stats
      than numpy).

    The lookup tries ``flopscope.numpy`` first. If the path does not resolve
    there, it falls back to top-level ``flopscope``, so submodules like
    ``stats`` continue to resolve.

    *op_name* arrives straight from the client, so the walk is restricted:

    - Components may not be empty or start with ``_``. This blocks
      dunder/private escapes such as ``__builtins__``.
    - The walk may not pass through a module outside the ``flopscope``
      package. This blocks, for example, a ``numpy`` or ``os`` module that
      happens to be imported into a flopscope namespace.

    This narrows what a client can reach but is not a full sandbox.

    Raises
    ------
    AttributeError
        If *op_name* is rejected or does not resolve under either base.
    """
    import types

    import flopscope.numpy as fnp

    parts = op_name.split(".")
    if not all(parts) or any(part.startswith("_") for part in parts):
        raise AttributeError(f"flopscope does not provide {op_name!r}")

    for base in (fnp, flops):
        obj = base
        try:
            for part in parts:
                obj = getattr(obj, part)
                if isinstance(obj, types.ModuleType) and not (
                    obj.__name__ == "flopscope"
                    or obj.__name__.startswith("flopscope.")
                ):
                    # Foreign module: treat as "not found" under this base.
                    raise AttributeError(part)
        except AttributeError:
            continue
        return obj
    raise AttributeError(f"flopscope does not provide {op_name!r}")
416
+
417
+
418
+ def _decode_index_key(raw_key):
419
+ """Decode a serialised index key from the client.
420
+
421
+ Supports:
422
+ - int / float -> int
423
+ - ``{"__slice__": [start, stop, step]}`` -> slice
424
+ - list of the above -> tuple (for multi-dimensional indexing)
425
+ """
426
+ if isinstance(raw_key, dict):
427
+ if "__slice__" in raw_key:
428
+ parts = raw_key["__slice__"]
429
+ return slice(*[None if p is None else int(p) for p in parts])
430
+ if isinstance(raw_key, list):
431
+ decoded = [_decode_index_key(item) for item in raw_key]
432
+ # A list of slices/ints -> tuple for multi-dim indexing
433
+ if any(isinstance(d, slice) for d in decoded) or len(decoded) > 1:
434
+ return tuple(decoded)
435
+ # Single-element list: could be the key itself being a list
436
+ # (e.g., fancy indexing) -- keep as list
437
+ return decoded
438
+ if isinstance(raw_key, (int, float)):
439
+ return int(raw_key)
440
+ return raw_key