flopscope-client 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flopscope/__init__.py +335 -0
- flopscope/_budget.py +324 -0
- flopscope/_comms_tracker.py +46 -0
- flopscope/_connection.py +187 -0
- flopscope/_constants.py +9 -0
- flopscope/_display.py +283 -0
- flopscope/_getattr.py +52 -0
- flopscope/_math_compat.py +25 -0
- flopscope/_perm_group.py +721 -0
- flopscope/_protocol.py +88 -0
- flopscope/_registry.py +66 -0
- flopscope/_registry_data.py +609 -0
- flopscope/_remote_array.py +691 -0
- flopscope/_weights.py +126 -0
- flopscope/data/__init__.py +1 -0
- flopscope/data/default_weights.json +462 -0
- flopscope/errors.py +107 -0
- flopscope/fft/__init__.py +9 -0
- flopscope/flops.py +140 -0
- flopscope/linalg/__init__.py +68 -0
- flopscope/random/__init__.py +45 -0
- flopscope/stats/__init__.py +82 -0
- flopscope_client-0.4.0.dist-info/METADATA +91 -0
- flopscope_client-0.4.0.dist-info/RECORD +25 -0
- flopscope_client-0.4.0.dist-info/WHEEL +4 -0
flopscope/__init__.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""flopscope — transparent proxy to a remote flopscope server.
|
|
2
|
+
|
|
3
|
+
This module exposes a numpy-like API where every operation is dispatched
|
|
4
|
+
to a remote server over ZMQ. Participants use it as::
|
|
5
|
+
|
|
6
|
+
import flopscope as flops
|
|
7
|
+
import flopscope.numpy as fnp
|
|
8
|
+
|
|
9
|
+
with flops.BudgetContext(flop_budget=1_000_000) as ctx:
|
|
10
|
+
a = fnp.array([[1.0, 2.0], [3.0, 4.0]])
|
|
11
|
+
b = fnp.zeros((2, 2))
|
|
12
|
+
c = fnp.add(a, b)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import builtins
|
|
18
|
+
import struct
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
# Package version string; this file ships in the 0.4.0 wheel.
__version__ = "0.4.0"
|
|
22
|
+
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
# Errors
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
# Budget
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
from flopscope._budget import ( # noqa: E402
|
|
31
|
+
BudgetContext,
|
|
32
|
+
OpRecord,
|
|
33
|
+
budget,
|
|
34
|
+
budget_summary_dict,
|
|
35
|
+
)
|
|
36
|
+
from flopscope._display import budget_live, budget_summary # noqa: E402
|
|
37
|
+
from flopscope._math_compat import e, inf, nan, pi # noqa: E402
|
|
38
|
+
from flopscope._perm_group import SymmetryGroup # noqa: E402
|
|
39
|
+
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
# Remote types
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
from flopscope._remote_array import ( # noqa: E402
|
|
44
|
+
_DTYPE_INFO,
|
|
45
|
+
RemoteArray,
|
|
46
|
+
RemoteScalar,
|
|
47
|
+
_encode_arg,
|
|
48
|
+
_result_from_response,
|
|
49
|
+
)
|
|
50
|
+
from flopscope.errors import ( # noqa: E402
|
|
51
|
+
BudgetExhaustedError,
|
|
52
|
+
FlopscopeError,
|
|
53
|
+
FlopscopeServerError,
|
|
54
|
+
FlopscopeWarning,
|
|
55
|
+
NoBudgetContextError,
|
|
56
|
+
SymmetryError,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
# numpy-style name for the remote array type: ``ndarray`` *is* ``RemoteArray``,
# so ``isinstance(x, ndarray)`` holds for every remote array.
ndarray = RemoteArray
|
|
61
|
+
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
# Connection / protocol (private)
|
|
64
|
+
# ---------------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# Submodules (imported so ``fnp.linalg``, ``fnp.random``, ``fnp.fft`` work)
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
from flopscope import (
|
|
70
|
+
fft, # noqa: E402, F401
|
|
71
|
+
flops, # noqa: E402, F401
|
|
72
|
+
linalg, # noqa: E402, F401
|
|
73
|
+
random, # noqa: E402, F401
|
|
74
|
+
stats, # noqa: E402, F401
|
|
75
|
+
)
|
|
76
|
+
from flopscope._connection import get_connection # noqa: E402
|
|
77
|
+
from flopscope._protocol import ( # noqa: E402
|
|
78
|
+
encode_create_from_data,
|
|
79
|
+
encode_request,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
# ---------------------------------------------------------------------------
|
|
83
|
+
# Registry
|
|
84
|
+
# ---------------------------------------------------------------------------
|
|
85
|
+
from flopscope._registry import ( # noqa: E402
|
|
86
|
+
BLACKLISTED,
|
|
87
|
+
FUNCTION_CATEGORIES,
|
|
88
|
+
get_category,
|
|
89
|
+
is_valid_op,
|
|
90
|
+
iter_proxyable,
|
|
91
|
+
)
|
|
92
|
+
from flopscope._registry_data import FUNCTION_CATEGORIES as _FC # noqa: E402
|
|
93
|
+
|
|
94
|
+
# ---------------------------------------------------------------------------
|
|
95
|
+
# Constants (no server round-trip needed)
|
|
96
|
+
# ---------------------------------------------------------------------------
|
|
97
|
+
|
|
98
|
+
# Math constants re-exported as plain Python floats (no server round-trip).
pi: float = pi
e: float = e
inf: float = inf
nan: float = nan
# As in numpy, ``newaxis`` is just ``None`` used inside indexing expressions.
newaxis = None

# ---------------------------------------------------------------------------
# Dtype names. Dtypes travel to the server as plain strings, so each of these
# attributes is simply the corresponding numpy dtype name.
# ---------------------------------------------------------------------------

float16: str = "float16"
float32: str = "float32"
float64: str = "float64"
int8: str = "int8"
int16: str = "int16"
int32: str = "int32"
int64: str = "int64"
uint8: str = "uint8"
bool_: str = "bool"
complex64: str = "complex64"
complex128: str = "complex128"
|
|
119
|
+
|
|
120
|
+
# ---------------------------------------------------------------------------
|
|
121
|
+
# Proxy factory
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _make_proxy(op_name: str):
    """Build a function that forwards calls of *op_name* to the server.

    The returned callable encodes each positional and keyword argument with
    ``_encode_arg``, sends one request over the shared connection and turns
    the reply into a Python value with ``_result_from_response``. Its
    ``__name__`` and ``__qualname__`` are set to *op_name* so reprs and
    tracebacks show the operation rather than an internal helper name.
    """

    def _forward(*args: Any, **kwargs: Any):
        connection = get_connection()
        payload = encode_request(
            op_name,
            args=list(map(_encode_arg, args)),
            kwargs={key: _encode_arg(value) for key, value in kwargs.items()},
        )
        return _result_from_response(connection.send_recv(payload))

    _forward.__name__ = _forward.__qualname__ = op_name
    return _forward
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# ---------------------------------------------------------------------------
|
|
143
|
+
# Special-case: array()
|
|
144
|
+
# ---------------------------------------------------------------------------
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _flatten(obj):
    """Flatten nested lists/tuples in row-major order.

    Returns ``(values, shape)``:

    - a non-sequence scalar gives ``([obj], ())``;
    - an empty sequence gives ``([], (0,))``.

    Raises
    ------
    ValueError
        If sibling elements do not all have the same nested shape.
    """
    if not isinstance(obj, (list, tuple)):
        return [obj], ()
    if not obj:
        return [], (0,)

    head, *tail = obj
    values, sub_shape = _flatten(head)
    values = list(values)
    for element in tail:
        element_values, element_shape = _flatten(element)
        if element_shape != sub_shape:
            raise ValueError(
                f"Inhomogeneous shape: expected inner shape {sub_shape}, "
                f"got {element_shape}"
            )
        values += element_values
    return values, (len(obj),) + sub_shape
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _infer_dtype(values):
    """Choose a dtype string for a flat list of Python scalars.

    Precedence, first match wins:

    1. any ``complex`` -> ``"complex128"``;
    2. any ``float`` -> ``"float64"``;
    3. all ``bool`` -> ``"bool"``;
    4. all ``int`` (bools count as ints) -> ``"int64"``;
    5. otherwise -> ``"float64"``.
    """
    # ``any``/``all`` are reached through ``builtins`` on purpose: the bare
    # module-level names may be rebound to generated server proxies.
    saw_float = builtins.any(isinstance(item, float) for item in values)
    saw_complex = builtins.any(isinstance(item, complex) for item in values)
    if saw_complex:
        return "complex128"
    if saw_float:
        return "float64"
    if builtins.all(isinstance(item, bool) for item in values):
        return "bool"
    all_ints = builtins.all(isinstance(item, int) for item in values)
    return "int64" if all_ints else "float64"
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _pack_values(values, dtype_str):
    """Pack a flat list of Python scalars as little-endian bytes of *dtype_str*.

    The struct format character comes from ``_DTYPE_INFO``. Complex dtypes
    are sent as interleaved ``(real, imag)`` pairs:

    - ``f`` values for ``complex64``;
    - ``d`` values for ``complex128``.

    Raises
    ------
    TypeError
        If *dtype_str* has no entry in ``_DTYPE_INFO``.
    struct.error
        If a value cannot be packed with the dtype's format character
        (e.g. a float with an integer dtype).
    """
    info = _DTYPE_INFO.get(dtype_str)
    if info is None:
        raise TypeError(f"Unsupported dtype: {dtype_str!r}")
    fmt_char, _ = info
    if dtype_str in ("complex64", "complex128"):
        interleaved = []
        for v in values:
            c = complex(v)
            interleaved.extend((c.real, c.imag))
        values = interleaved
        fmt_char = "f" if dtype_str == "complex64" else "d"
    return struct.pack(f"<{len(values)}{fmt_char}", *values)


def _create_remote(data, shape, dtype_str):
    """Ask the server to build an array of *shape*/*dtype_str* from raw bytes."""
    conn = get_connection()
    resp = conn.send_recv(encode_create_from_data(data, list(shape), dtype_str))
    return _result_from_response(resp)


def array(object, dtype=None, **kwargs):
    """Create a remote array from a Python list, tuple, scalar or RemoteArray.

    Parameters
    ----------
    object:
        Data to create the array from. This may be:

        - a nested list/tuple of numbers;
        - a Python ``int``/``float``/``complex``/``bool`` scalar (giving a
          0-d array);
        - an existing :class:`RemoteArray`.
    dtype:
        Optional dtype string (e.g. ``"float64"``). Inferred from data if not
        given, for both sequences and scalars (so ``array(3)`` and
        ``array([3])`` both get ``"int64"``).
    **kwargs:
        Accepted for numpy signature compatibility; currently ignored.

    Returns
    -------
    RemoteArray
        A new remote array on the server. If *object* is already a
        RemoteArray and *dtype* is ``None``, *object* itself is returned.

    Raises
    ------
    TypeError
        For an unsupported dtype or an unsupported *object* type.
    ValueError
        For ragged nested sequences.
    """
    if isinstance(object, RemoteArray):
        if dtype is None:
            return object
        # dtype cast: dispatch to server
        conn = get_connection()
        resp = conn.send_recv(
            encode_request("astype", args=[{"__handle__": object.handle_id}, dtype])
        )
        return _result_from_response(resp)

    if isinstance(object, (list, tuple)):
        flat, shape = _flatten(object)
        if not flat:
            # Empty array: nothing to infer from, default to float64.
            dtype_str = dtype if isinstance(dtype, str) else "float64"
            return _create_remote(b"", shape, dtype_str)
        dtype_str = dtype if isinstance(dtype, str) else (dtype or _infer_dtype(flat))
        return _create_remote(_pack_values(flat, dtype_str), shape, dtype_str)

    if isinstance(object, (int, float, complex)):
        # Scalar -> 0-d array. Infer the dtype exactly as the sequence path
        # does, so a scalar and a one-element list agree.
        if dtype is None:
            dtype_str = _infer_dtype([object])
        elif isinstance(dtype, str):
            dtype_str = dtype
        else:
            # Non-string dtypes keep the historical float64 fallback here.
            dtype_str = "float64"
        return _create_remote(_pack_values([object], dtype_str), (), dtype_str)

    raise TypeError(
        f"Cannot create array from {type(object).__name__}. "
        f"Expected list, tuple, int, float, complex, or RemoteArray."
    )
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# ---------------------------------------------------------------------------
|
|
271
|
+
# Special-case: einsum()
|
|
272
|
+
# ---------------------------------------------------------------------------
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def einsum(subscripts, *operands, **kwargs):
    """Evaluate an Einstein-summation expression on the server.

    Parameters
    ----------
    subscripts:
        Subscript string such as ``"ij,jk->ik"``; sent to the server as-is.
    *operands:
        Input :class:`RemoteArray` objects, each encoded with ``_encode_arg``.
    **kwargs:
        Extra keyword arguments, encoded with ``_encode_arg`` and forwarded.

    Returns
    -------
    RemoteArray
        Whatever ``_result_from_response`` builds from the server reply.
    """
    connection = get_connection()
    request = encode_request(
        "einsum",
        args=[subscripts, *map(_encode_arg, operands)],
        kwargs={name: _encode_arg(val) for name, val in kwargs.items()},
    )
    return _result_from_response(connection.send_recv(request))
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# ---------------------------------------------------------------------------
|
|
302
|
+
# Auto-generate proxy functions for all non-blacklisted top-level ops
|
|
303
|
+
# ---------------------------------------------------------------------------
|
|
304
|
+
|
|
305
|
+
# Names implemented by hand above (``array``, ``einsum``); the generated
# proxies must not overwrite them.
_SPECIAL_CASED = frozenset({"array", "einsum"})

# Every proxyable top-level op gets a module attribute that forwards to the
# server. Dotted names (``"<submodule>.<op>"``) are skipped; the submodule
# packages imported above are expected to proxy those themselves.
_generated_proxies: list[str] = []
for _op_name in iter_proxyable():
    if "." in _op_name or _op_name in _SPECIAL_CASED:
        continue
    globals()[_op_name] = _make_proxy(_op_name)
    _generated_proxies.append(_op_name)

# Remove the loop variable from the module namespace. ``pop`` with a default
# is used instead of ``del``: if the registry yields nothing, ``_op_name`` is
# never bound and ``del`` would raise NameError at import time.
globals().pop("_op_name", None)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# ---------------------------------------------------------------------------
|
|
323
|
+
# Module-level __getattr__ for blacklisted / unknown names
|
|
324
|
+
# ---------------------------------------------------------------------------
|
|
325
|
+
|
|
326
|
+
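# Build the fallback resolver from the shared factory and expose it through a
# thin module-level ``__getattr__`` (PEP 562). Python only consults that hook
# for names missing from the module namespace, so everything defined above
# (including the generated proxies) always takes precedence over it.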
|
|
328
|
+
|
|
329
|
+
from flopscope._getattr import make_module_getattr as _make_module_getattr # noqa: E402
|
|
330
|
+
|
|
331
|
+
_module_getattr = _make_module_getattr("", "flopscope")


def __getattr__(name: str):
    """Fallback attribute lookup for the ``flopscope`` module (PEP 562).

    Python calls this only when *name* is not found in the module namespace.
    The lookup is delegated to the resolver built by ``make_module_getattr``.
    """
    resolver = _module_getattr
    return resolver(name)
|
flopscope/_budget.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
"""Client-side BudgetContext proxy that delegates to the flopscope server."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from flopscope._connection import get_connection
|
|
6
|
+
from flopscope._protocol import (
|
|
7
|
+
encode_budget_close,
|
|
8
|
+
encode_budget_open,
|
|
9
|
+
encode_budget_status,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# The BudgetContext currently in effect, or ``None``.
# ``BudgetContext.__enter__`` reads it to reject a second simultaneous
# context; the implicit global default context is exempt from that check.
_active_context = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OpRecord:
    """FLOP cost entry for one operation.

    Parameters
    ----------
    op_name:
        Operation name, e.g. ``"dot"``.
    flop_cost:
        FLOPs charged for this single operation.
    cumulative:
        Running FLOP total after this operation was charged.
    """

    def __init__(self, op_name: str, flop_cost: int, cumulative: int) -> None:
        self.op_name, self.flop_cost, self.cumulative = op_name, flop_cost, cumulative

    def __repr__(self) -> str:  # pragma: no cover
        name, cost, total = self.op_name, self.flop_cost, self.cumulative
        return f"OpRecord(op_name={name!r}, flop_cost={cost}, cumulative={total})"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BudgetContext:
    """Context manager that opens/closes a FLOP budget on the server.

    Parameters
    ----------
    flop_budget:
        Maximum FLOPs allowed within this context.
    flop_multiplier:
        Scaling factor applied to each operation's raw FLOP count before
        it is charged against the budget. Defaults to ``1.0``.
    quiet:
        If ``True``, suppress informational output. Defaults to ``False``.
        (Stored and exposed via :attr:`quiet`; this class does not read it.)
    namespace:
        Optional label for grouping budget records.

    Only one explicit context may be active at a time. Entering a second one
    while another context (other than the implicit global default) is active
    raises :class:`RuntimeError`. An instance can also decorate a function,
    in which case each call of that function runs inside the context.

    Example
    -------
    >>> with BudgetContext(flop_budget=1_000_000) as ctx:
    ...     result = flopscope.dot(a, b)
    ...     print(ctx.summary())
    """

    def __init__(
        self,
        flop_budget: int,
        flop_multiplier: float = 1.0,
        quiet: bool = False,
        namespace: str | None = None,
    ) -> None:
        self._flop_budget = flop_budget
        self._flop_multiplier = flop_multiplier
        self._quiet = quiet
        self._namespace = namespace
        # Local mirror of the server-side counter, refreshed from responses.
        self._flops_used: int = 0
        self._close_summary: str | None = None
        self._is_open: bool = False
        # Context that was active before __enter__; restored on exit.
        self._previous_context = None

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def flop_budget(self) -> int:
        """Maximum FLOP allowance for this context."""
        return self._flop_budget

    @property
    def flops_used(self) -> int:
        """FLOPs consumed so far (cached locally, updated from server responses)."""
        return self._flops_used

    @property
    def flops_remaining(self) -> int:
        """FLOPs remaining in the budget (``budget - used``)."""
        return self._flop_budget - self._flops_used

    @property
    def flop_multiplier(self) -> float:
        """FLOP scaling multiplier."""
        return self._flop_multiplier

    @property
    def quiet(self) -> bool:
        """Whether informational output is suppressed."""
        return self._quiet

    @property
    def namespace(self) -> str | None:
        """Optional namespace label for this context."""
        return self._namespace

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _update_budget(self, budget_info: dict) -> None:
        """Update the local ``flops_used`` cache from a server-response dict.

        Parameters
        ----------
        budget_info:
            Dict that may contain a ``"flops_used"`` key. A missing key is
            silently ignored.
        """
        # NOTE(review): summary() unwraps a nested "result" dict first, but
        # __enter__/__exit__ pass the raw response; confirm the open/close
        # replies carry "flops_used" at the top level.
        if "flops_used" in budget_info:
            self._flops_used = int(budget_info["flops_used"])

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def summary(self) -> str:
        """Query the server for current budget status and return a formatted string.

        Also updates the local ``flops_used`` cache.

        Returns
        -------
        str
            Human-readable summary of budget usage.
        """
        conn = get_connection()
        response = conn.send_recv(encode_budget_status())
        # Budget status is nested inside "result" key
        result = response.get("result", {})
        self._update_budget(result)
        budget = result.get("flop_budget", self._flop_budget)
        used = self._flops_used
        remaining = int(budget) - used
        return f"BudgetContext: {used}/{budget} FLOPs used ({remaining} remaining)"

    # ------------------------------------------------------------------
    # Decorator support
    # ------------------------------------------------------------------

    def __call__(self, func):
        """Use BudgetContext as a decorator: each call of *func* runs inside ``with self``."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper

    # ------------------------------------------------------------------
    # Context manager protocol
    # ------------------------------------------------------------------

    def __enter__(self) -> BudgetContext:
        """Open the budget on the server and mark this context as active.

        Raises
        ------
        RuntimeError
            If another non-default BudgetContext is already active.
        """
        global _active_context
        if _active_context is not None and _active_context is not _global_default:
            raise RuntimeError(
                "Nested BudgetContext is not supported. "
                "Only one context can be active at a time."
            )
        previous = _active_context
        conn = get_connection()
        response = conn.send_recv(
            encode_budget_open(self._flop_budget, self._flop_multiplier)
        )
        self._update_budget(response)
        # Only link to the previous context once the server accepted the open.
        self._previous_context = previous
        self._is_open = True
        _active_context = self
        return self

    def __exit__(self, *args: object) -> None:
        """Close the budget on the server and store the close summary.

        The module-level active-context guard is released in a ``finally``
        block. A failed close request therefore does not leave every later
        BudgetContext rejected as nested. Exceptions raised inside the
        ``with`` body are never suppressed (this returns ``None``).
        """
        global _active_context
        if not self._is_open:
            return
        try:
            conn = get_connection()
            response = conn.send_recv(encode_budget_close())
            self._update_budget(response)
            self._close_summary = (
                f"BudgetContext closed: {self._flops_used}/{self._flop_budget} "
                f"FLOPs used"
            )
            _accumulator.record(self)
        finally:
            # NOTE(review): after a failed close the server may still treat
            # this budget as open — confirm how it handles the next open.
            self._is_open = False
            _active_context = self._previous_context

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"BudgetContext(flop_budget={self._flop_budget}, "
            f"flops_used={self._flops_used}, "
            f"flop_multiplier={self._flop_multiplier})"
        )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# ------------------------------------------------------------------
|
|
214
|
+
# Accumulator
|
|
215
|
+
# ------------------------------------------------------------------
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class NamespaceRecord:
    """Snapshot of one BudgetContext taken when it closed.

    Holds the context's ``namespace``, ``flop_budget`` and ``flops_used``
    as plain attributes.
    """

    def __init__(self, namespace, flop_budget, flops_used):
        self.namespace, self.flop_budget, self.flops_used = (
            namespace,
            flop_budget,
            flops_used,
        )
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class BudgetAccumulator:
    """Aggregate snapshots of closed BudgetContext sessions.

    :meth:`record` appends one :class:`NamespaceRecord` per closed context.
    :meth:`get_data` sums them, optionally grouped by namespace.
    """

    def __init__(self):
        self._records = []

    def record(self, ctx):
        """Store a snapshot of *ctx*'s namespace, budget and usage."""
        snapshot = NamespaceRecord(
            namespace=ctx.namespace,
            flop_budget=ctx.flop_budget,
            flops_used=ctx.flops_used,
        )
        self._records.append(snapshot)

    def get_data(self, by_namespace=False):
        """Return totals across all recorded sessions.

        The result has ``flop_budget``, ``flops_used``, ``flops_remaining`` and
        an (always empty) ``operations`` dict. When *by_namespace* is true it
        also has ``by_namespace``, which maps each namespace (``None``
        included) to its own summed ``flop_budget``/``flops_used``.
        """
        budget_total = 0
        used_total = 0
        for rec in self._records:
            budget_total += rec.flop_budget
            used_total += rec.flops_used

        data = {
            "flop_budget": budget_total,
            "flops_used": used_total,
            "flops_remaining": budget_total - used_total,
            "operations": {},
        }
        if not by_namespace:
            return data

        grouped = {}
        for rec in self._records:
            bucket = grouped.setdefault(
                rec.namespace,
                {"flop_budget": 0, "flops_used": 0, "operations": {}},
            )
            bucket["flop_budget"] += rec.flop_budget
            bucket["flops_used"] += rec.flops_used
        data["by_namespace"] = grouped
        return data

    def reset(self):
        """Forget every recorded session (clears the list in place)."""
        del self._records[:]


# Process-wide accumulator fed by BudgetContext.__exit__.
_accumulator = BudgetAccumulator()
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def budget(flop_budget, flop_multiplier=1.0, quiet=False, namespace=None):
    """Build a :class:`BudgetContext` from the given settings.

    The result works both as ``with budget(...):`` and as a ``@budget(...)``
    decorator, because BudgetContext implements ``__enter__``/``__exit__``
    and ``__call__``.
    """
    options = dict(
        flop_multiplier=flop_multiplier,
        quiet=quiet,
        namespace=namespace,
    )
    return BudgetContext(flop_budget, **options)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def budget_summary_dict(by_namespace=False):
    """Return totals aggregated over every BudgetContext recorded so far.

    Thin wrapper over the module accumulator's ``get_data``; pass
    ``by_namespace=True`` to also get per-namespace totals.
    """
    accumulator = _accumulator
    return accumulator.get_data(by_namespace=by_namespace)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# Note: No budget_reset() in the client — participants must not clear usage.
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# Lazily created fallback context: built and opened on the server by
# ``_get_global_default()``, cleared again by ``_reset_global_default()``.
_global_default = None
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _get_default_budget_amount():
    """Return the FLOP budget for the implicit global default context.

    Reads ``FLOPSCOPE_DEFAULT_BUDGET`` when it is set:

    - plain integers such as ``"5000"`` are parsed exactly, so large budgets
      keep every digit;
    - other numeric forms such as ``"1e12"`` or ``"2.5"`` go through
      ``float`` and are truncated toward zero.

    When the variable is unset, falls back to ``10**15``.

    Raises
    ------
    ValueError
        If the variable is set but is not a finite number.
    """
    import math
    import os

    raw = os.environ.get("FLOPSCOPE_DEFAULT_BUDGET")
    if raw is None:
        return int(1e15)
    try:
        # Exact path first: int(float(...)) would round integers above 2**53.
        return int(raw)
    except ValueError:
        pass
    try:
        value = float(raw)
    except ValueError:
        raise ValueError(
            f"FLOPSCOPE_DEFAULT_BUDGET must be a number, got {raw!r}"
        ) from None
    if not math.isfinite(value):
        raise ValueError(
            f"FLOPSCOPE_DEFAULT_BUDGET must be a finite number, got {raw!r}"
        )
    return int(value)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _get_global_default():
    """Return the process-wide fallback BudgetContext, opening it on first use.

    On the first successful call:

    - a quiet BudgetContext sized by ``_get_default_budget_amount()`` is
      created;
    - a budget-open request for it is sent to the server;
    - it becomes both the cached default and the active context.

    Later calls return the cached instance without contacting the server.
    If the open request raises, nothing is cached and the next call retries.
    """
    global _global_default, _active_context
    if _global_default is None:
        ctx = BudgetContext(
            flop_budget=_get_default_budget_amount(),
            quiet=True,
            namespace=None,
        )
        # Open it on the server
        conn = get_connection()
        response = conn.send_recv(
            encode_budget_open(ctx._flop_budget, ctx._flop_multiplier)
        )
        ctx._update_budget(response)
        ctx._is_open = True
        # Publish only after the server accepted the open, so a failed open
        # never leaves behind a cached context the server does not know about.
        _global_default = ctx
        _active_context = ctx
    return _global_default
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _reset_global_default():
    """Drop the cached global default context on the client side.

    If the default is the active context, the active slot is cleared as
    well. No close request is sent to the server.
    """
    global _global_default, _active_context
    default_ctx = _global_default
    if default_ctx is not None and _active_context is default_ctx:
        _active_context = None
    _global_default = None
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Client-side communications tracker."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ClientCommsTracker:
    """Running totals of request/response traffic for the flopscope client.

    Each :meth:`record` call adds one round trip. The read-only properties
    expose the accumulated counts.
    """

    def __init__(self) -> None:
        self._request_count: int = 0
        self._total_round_trip_ns: int = 0
        self._total_request_bytes: int = 0
        self._total_response_bytes: int = 0

    def record(
        self,
        *,
        round_trip_ns: int,
        request_bytes: int,
        response_bytes: int,
    ) -> None:
        """Add one round trip's latency (ns) and payload sizes (bytes)."""
        self._request_count = self._request_count + 1
        self._total_round_trip_ns = self._total_round_trip_ns + round_trip_ns
        self._total_request_bytes = self._total_request_bytes + request_bytes
        self._total_response_bytes = self._total_response_bytes + response_bytes

    @property
    def request_count(self) -> int:
        """Number of round trips recorded."""
        return self._request_count

    @property
    def total_round_trip_ns(self) -> int:
        """Accumulated round-trip latency, in nanoseconds."""
        return self._total_round_trip_ns

    @property
    def total_request_bytes(self) -> int:
        """Accumulated request payload size, in bytes."""
        return self._total_request_bytes

    @property
    def total_response_bytes(self) -> int:
        """Accumulated response payload size, in bytes."""
        return self._total_response_bytes
|