flashspec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashspec/__init__.py +43 -0
- flashspec/bandit/__init__.py +14 -0
- flashspec/bandit/base.py +402 -0
- flashspec/bandit/oracle.py +181 -0
- flashspec/bandit/thompson.py +178 -0
- flashspec/bandit/ucb.py +175 -0
- flashspec/engine/__init__.py +15 -0
- flashspec/engine/drafter.py +247 -0
- flashspec/engine/speculative.py +257 -0
- flashspec/engine/verifier.py +205 -0
- flashspec/export/__init__.py +5 -0
- flashspec/export/onnx.py +113 -0
- flashspec/kernels/__init__.py +18 -0
- flashspec/kernels/_reference.py +196 -0
- flashspec/kernels/gather_kernel.py +136 -0
- flashspec/kernels/verify_kernel.py +228 -0
- flashspec/metrics/__init__.py +11 -0
- flashspec/metrics/acceptance.py +175 -0
- flashspec/metrics/latency.py +234 -0
- flashspec/metrics/throughput.py +249 -0
- flashspec/py.typed +0 -0
- flashspec/sampling/__init__.py +9 -0
- flashspec/sampling/rejection.py +235 -0
- flashspec/sampling/typical.py +138 -0
- flashspec/utils/__init__.py +20 -0
- flashspec/utils/config.py +159 -0
- flashspec/utils/device.py +165 -0
- flashspec/utils/logging.py +117 -0
- flashspec-0.1.0.dist-info/METADATA +331 -0
- flashspec-0.1.0.dist-info/RECORD +32 -0
- flashspec-0.1.0.dist-info/WHEEL +4 -0
- flashspec-0.1.0.dist-info/licenses/LICENSE +117 -0
flashspec/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""FlashSpec — Adaptive speculative-decoding inference engine.

Adaptive speculative-decoding inference engine with Triton-optimised
verification and online bandit draft selection.

Public API surface (AGENTS.md §13.2 — do not modify without explicit approval):

    flashspec.SpeculativeEngine
    flashspec.GenerationResult
    flashspec.FlashSpecConfig
    flashspec.BanditConfig
    flashspec.SamplingConfig
    flashspec.MetricsConfig
    flashspec.register (draft model decorator)
    flashspec.get_drafter
    flashspec.list_drafters

References
----------
.. [1] Leviathan et al. (2023), "Fast Inference from Transformers via
   Speculative Decoding", arXiv:2211.17192.
.. [2] Myet (2025), "FlashSpec: Adaptive Speculative Decoding with Online
   Bandit Draft Selection and Triton-Optimised Verification".
"""

# Draft-model registry helpers.
from flashspec.engine.drafter import get_drafter, list_drafters, register

# Engine entry point and generation result.
from flashspec.engine.speculative import GenerationResult, SpeculativeEngine

# Configuration objects.
from flashspec.utils.config import (
    BanditConfig,
    FlashSpecConfig,
    MetricsConfig,
    SamplingConfig,
)

# Names exported by ``from flashspec import *`` (kept in sorted order).
__all__ = [
    "BanditConfig",
    "FlashSpecConfig",
    "GenerationResult",
    "MetricsConfig",
    "SamplingConfig",
    "SpeculativeEngine",
    "get_drafter",
    "list_drafters",
    "register",
]

# Package metadata.
__version__ = "0.1.0"
__author__ = "Min Htet Myet (Mattral)"
|
|
flashspec/bandit/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Online bandit draft selector sub-package.

Re-exports the selector base class, its per-arm statistics record, and the
concrete UCB1, Thompson and Oracle selectors.
"""

from flashspec.bandit.base import ArmStats, DraftSelector
from flashspec.bandit.oracle import OracleSelector
from flashspec.bandit.thompson import ThompsonSelector
from flashspec.bandit.ucb import UCB1Selector

# Names exported by ``from flashspec.bandit import *``.
__all__ = [
    "ArmStats",
    "DraftSelector",
    "OracleSelector",
    "ThompsonSelector",
    "UCB1Selector",
]
|
flashspec/bandit/base.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
"""Abstract base class for online bandit draft selectors.

Every concrete selector (UCB1, Thompson, Oracle) derives from ``DraftSelector``
and is expected to honour its JSON serialisation and thread-safety contracts.
"""

from __future__ import annotations

import json
import threading
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any

# Public names exported by this module.
__all__ = ["DraftSelector", "ArmStats"]
|
|
17
|
+
|
|
18
|
+
# ── Value object for per-arm statistics ───────────────────────────────────────
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(slots=True, frozen=False)
|
|
22
|
+
class ArmStats:
|
|
23
|
+
"""Mutable per-arm statistics used by bandit selectors.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
n_pulls : int
|
|
28
|
+
Total number of times this arm has been selected.
|
|
29
|
+
n_accepted : int
|
|
30
|
+
Total number of accepted tokens attributed to this arm.
|
|
31
|
+
window_accepts : deque[int]
|
|
32
|
+
Rolling window of per-round accept counts (1 or 0) for windowed stats.
|
|
33
|
+
window_size : int
|
|
34
|
+
Maximum size of the rolling window. 0 disables windowing.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
n_pulls: int = 0
|
|
38
|
+
n_accepted: int = 0
|
|
39
|
+
window_accepts: deque[int] = field(default_factory=deque)
|
|
40
|
+
window_size: int = 500
|
|
41
|
+
|
|
42
|
+
def record(self, accepted: int) -> None:
|
|
43
|
+
"""Record the outcome of one round for this arm.
|
|
44
|
+
|
|
45
|
+
Parameters
|
|
46
|
+
----------
|
|
47
|
+
accepted : int
|
|
48
|
+
Number of tokens accepted in this round (typically 0 or 1).
|
|
49
|
+
|
|
50
|
+
Returns
|
|
51
|
+
-------
|
|
52
|
+
None
|
|
53
|
+
|
|
54
|
+
Notes
|
|
55
|
+
-----
|
|
56
|
+
When ``window_size > 0`` the oldest entry is evicted once the window
|
|
57
|
+
is full, so ``mean_accept_rate`` reflects only the most recent
|
|
58
|
+
``window_size`` rounds.
|
|
59
|
+
|
|
60
|
+
Examples
|
|
61
|
+
--------
|
|
62
|
+
>>> stats = ArmStats(window_size=100)
|
|
63
|
+
>>> stats.record(accepted=1)
|
|
64
|
+
>>> stats.n_pulls
|
|
65
|
+
1
|
|
66
|
+
"""
|
|
67
|
+
self.n_pulls += 1
|
|
68
|
+
self.n_accepted += accepted
|
|
69
|
+
if self.window_size > 0:
|
|
70
|
+
self.window_accepts.append(accepted)
|
|
71
|
+
if len(self.window_accepts) > self.window_size:
|
|
72
|
+
self.window_accepts.popleft()
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def mean_accept_rate(self) -> float:
|
|
76
|
+
"""Mean acceptance rate, optionally windowed.
|
|
77
|
+
|
|
78
|
+
Returns
|
|
79
|
+
-------
|
|
80
|
+
float
|
|
81
|
+
Windowed mean if ``window_size > 0`` and there are observations,
|
|
82
|
+
else global mean, else 0.0.
|
|
83
|
+
|
|
84
|
+
Notes
|
|
85
|
+
-----
|
|
86
|
+
When windowing is enabled (``window_size > 0``) the rate reflects
|
|
87
|
+
only the last ``window_size`` rounds, allowing the bandit to track
|
|
88
|
+
non-stationary acceptance distributions.
|
|
89
|
+
|
|
90
|
+
Examples
|
|
91
|
+
--------
|
|
92
|
+
>>> stats = ArmStats(window_size=0)
|
|
93
|
+
>>> stats.record(1); stats.record(0)
|
|
94
|
+
>>> stats.mean_accept_rate
|
|
95
|
+
0.5
|
|
96
|
+
"""
|
|
97
|
+
if self.window_size > 0 and self.window_accepts:
|
|
98
|
+
return sum(self.window_accepts) / len(self.window_accepts)
|
|
99
|
+
if self.n_pulls > 0:
|
|
100
|
+
return self.n_accepted / self.n_pulls
|
|
101
|
+
return 0.0
|
|
102
|
+
|
|
103
|
+
def to_dict(self) -> dict[str, Any]:
|
|
104
|
+
"""Serialise to a JSON-compatible dict.
|
|
105
|
+
|
|
106
|
+
Returns
|
|
107
|
+
-------
|
|
108
|
+
dict[str, Any]
|
|
109
|
+
Dictionary with keys ``n_pulls``, ``n_accepted``,
|
|
110
|
+
``window_accepts``, and ``window_size``.
|
|
111
|
+
|
|
112
|
+
Notes
|
|
113
|
+
-----
|
|
114
|
+
The returned dict can be passed directly to :meth:`from_dict` to
|
|
115
|
+
reconstruct an identical ``ArmStats`` instance.
|
|
116
|
+
|
|
117
|
+
Examples
|
|
118
|
+
--------
|
|
119
|
+
>>> stats = ArmStats(n_pulls=5, n_accepted=3, window_size=10)
|
|
120
|
+
>>> d = stats.to_dict()
|
|
121
|
+
>>> d["n_pulls"]
|
|
122
|
+
5
|
|
123
|
+
"""
|
|
124
|
+
return {
|
|
125
|
+
"n_pulls": self.n_pulls,
|
|
126
|
+
"n_accepted": self.n_accepted,
|
|
127
|
+
"window_accepts": list(self.window_accepts),
|
|
128
|
+
"window_size": self.window_size,
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def from_dict(cls, d: dict[str, Any]) -> "ArmStats":
|
|
133
|
+
"""Deserialise from a dict produced by :meth:`to_dict`.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
d : dict[str, Any]
|
|
138
|
+
Dictionary as returned by :meth:`to_dict`.
|
|
139
|
+
|
|
140
|
+
Returns
|
|
141
|
+
-------
|
|
142
|
+
ArmStats
|
|
143
|
+
Reconstructed instance with identical statistics.
|
|
144
|
+
|
|
145
|
+
Notes
|
|
146
|
+
-----
|
|
147
|
+
The ``window_accepts`` deque is reconstructed with the original
|
|
148
|
+
``window_size`` as its ``maxlen``.
|
|
149
|
+
|
|
150
|
+
Examples
|
|
151
|
+
--------
|
|
152
|
+
>>> stats = ArmStats(window_size=50)
|
|
153
|
+
>>> stats.record(1)
|
|
154
|
+
>>> restored = ArmStats.from_dict(stats.to_dict())
|
|
155
|
+
>>> restored.n_pulls == stats.n_pulls
|
|
156
|
+
True
|
|
157
|
+
"""
|
|
158
|
+
obj = cls(
|
|
159
|
+
n_pulls=d["n_pulls"],
|
|
160
|
+
n_accepted=d["n_accepted"],
|
|
161
|
+
window_size=d["window_size"],
|
|
162
|
+
)
|
|
163
|
+
obj.window_accepts = deque(
|
|
164
|
+
d["window_accepts"], maxlen=d["window_size"] or None
|
|
165
|
+
)
|
|
166
|
+
return obj
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# ── Abstract selector ─────────────────────────────────────────────────────────
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class DraftSelector(ABC):
    """Abstract base class for online bandit draft-model selectors.

    Subclasses implement :meth:`select`, :meth:`update`, :meth:`_state_dict`
    and :meth:`_from_state_dict`. The concrete methods provided here
    (:meth:`reset`, :meth:`to_json`) acquire a per-instance
    ``threading.Lock``. Subclasses are expected to do the same in
    :meth:`select` and :meth:`update`.

    Parameters
    ----------
    n_arms : int
        Number of draft-model arms.
    window_size : int
        Rolling window size for acceptance statistics (0 = disabled).

    Raises
    ------
    ValueError
        If ``n_arms`` < 1 or ``window_size`` < 0.

    Notes
    -----
    The selector maintains one :class:`ArmStats` object per arm.
    The internal round counter ``t`` counts total calls to :meth:`update`.
    Public methods acquire ``self._lock`` before mutating state so that
    multiple generation workers can share a single selector safely.

    Examples
    --------
    >>> selector = UCB1Selector(n_arms=3, window_size=200)
    >>> arm = selector.select()
    >>> selector.update(arm, accepted=1)
    """

    def __init__(self, n_arms: int, window_size: int = 500) -> None:
        if n_arms < 1:
            raise ValueError(f"n_arms must be >= 1; got {n_arms}.")
        if window_size < 0:
            raise ValueError(f"window_size must be >= 0; got {window_size}.")
        self._n_arms = n_arms
        self._window_size = window_size
        self._arms: list[ArmStats] = [
            ArmStats(window_size=window_size) for _ in range(n_arms)
        ]
        self._t: int = 0
        self._lock = threading.Lock()

    # ── Public interface ───────────────────────────────────────────────────

    @property
    def n_arms(self) -> int:
        """Number of arms.

        Returns
        -------
        int
            Count of available draft-model arms.

        Notes
        -----
        Fixed at construction time; cannot be changed after initialisation.

        Examples
        --------
        >>> selector = UCB1Selector(n_arms=3)
        >>> selector.n_arms
        3
        """
        return self._n_arms

    @property
    def t(self) -> int:
        """Total rounds elapsed (equal to the number of :meth:`update` calls).

        Returns
        -------
        int
            Non-negative integer round counter.

        Notes
        -----
        Resets to 0 after :meth:`reset` is called.

        Examples
        --------
        >>> selector = UCB1Selector(n_arms=2)
        >>> selector.update(0, accepted=1)
        >>> selector.t
        1
        """
        return self._t

    @abstractmethod
    def select(self) -> int:
        """Select an arm index to pull.

        Returns
        -------
        int
            Index in ``[0, n_arms)``.

        Notes
        -----
        Implementations must be thread-safe (acquire ``self._lock`` around
        any read-modify-write on shared state).

        Examples
        --------
        >>> arm = selector.select()
        >>> assert 0 <= arm < selector.n_arms
        """

    @abstractmethod
    def update(self, arm: int, accepted: int) -> None:
        """Record the outcome of pulling an arm.

        Parameters
        ----------
        arm : int
            Index of the arm that was pulled.
        accepted : int
            Number of tokens accepted in this round.

        Raises
        ------
        ValueError
            If ``arm`` is not in ``[0, n_arms)``.

        Notes
        -----
        Implementations are expected to increment the internal round counter
        ``t`` and delegate to ``self._arms[arm].record(accepted)``.

        Examples
        --------
        >>> selector.update(0, accepted=1)
        """

    def reset(self) -> None:
        """Reset all arm statistics and the round counter to zero.

        Returns
        -------
        None

        Notes
        -----
        Intended for per-context-window resets when the prompt distribution
        shifts and accumulated statistics are no longer representative.
        Thread-safe: acquires ``self._lock`` before mutating state.

        Examples
        --------
        >>> selector.reset()
        >>> selector.t
        0
        """
        with self._lock:
            self._arms = [
                ArmStats(window_size=self._window_size)
                for _ in range(self._n_arms)
            ]
            self._t = 0

    def to_json(self) -> str:
        """Serialise bandit state to a JSON string.

        Returns
        -------
        str
            Compact JSON-encoded bandit state suitable for checkpointing.

        Notes
        -----
        Thread-safe: acquires ``self._lock`` before reading state.
        The returned string can be passed to :meth:`from_json` on the same
        concrete subclass to reconstruct an identical instance.

        Examples
        --------
        >>> state_json = selector.to_json()
        >>> selector2 = UCB1Selector.from_json(state_json)
        """
        with self._lock:
            return json.dumps(self._state_dict(), separators=(",", ":"))

    @classmethod
    def from_json(cls, json_str: str) -> "DraftSelector":
        """Restore bandit state from a JSON string produced by :meth:`to_json`.

        Parameters
        ----------
        json_str : str
            JSON string previously produced by :meth:`to_json`.

        Returns
        -------
        DraftSelector
            Restored selector instance with identical state.

        Raises
        ------
        TypeError
            If called on an abstract class (e.g. ``DraftSelector`` itself).
        ValueError
            If ``json_str`` is not valid JSON, does not encode a JSON object,
            or is missing required fields.

        Notes
        -----
        Delegates to ``cls._from_state_dict``. The base implementation does
        not dispatch on the ``"type"`` key, so call this on the concrete
        subclass that produced the JSON.

        Examples
        --------
        >>> json_str = selector.to_json()
        >>> restored = UCB1Selector.from_json(json_str)
        >>> restored.t == selector.t
        True
        """
        # On an abstract class, ``_from_state_dict`` would run the empty
        # abstract body and silently return ``None``; fail loudly instead.
        if getattr(cls, "__abstractmethods__", None):
            raise TypeError(
                f"{cls.__name__}.from_json must be called on a concrete "
                "selector subclass."
            )
        try:
            state = json.loads(json_str)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON for bandit state: {exc}") from exc
        if not isinstance(state, dict):
            raise ValueError(
                "Bandit state must be a JSON object; "
                f"got {type(state).__name__}."
            )
        try:
            return cls._from_state_dict(state)
        except KeyError as exc:
            # Honour the documented contract: missing fields -> ValueError.
            raise ValueError(
                f"Bandit state is missing required field {exc}."
            ) from exc

    # ── Subclass hooks ─────────────────────────────────────────────────────

    @abstractmethod
    def _state_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable dict of all state."""

    @classmethod
    @abstractmethod
    def _from_state_dict(cls, state: dict[str, Any]) -> "DraftSelector":
        """Restore an instance from a state dict."""
|
|
flashspec/bandit/oracle.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Oracle bandit selector — upper-bound baseline for regret experiments.

The Oracle always selects the arm whose *true* acceptance rate is highest;
those rates must be supplied externally. It exists only to compute the regret
upper bound in experiments and is never used in production inference.
"""

from __future__ import annotations

from typing import Any

from flashspec.bandit.base import ArmStats, DraftSelector
from flashspec.utils.logging import get_logger

# Public names exported by this module.
__all__ = ["OracleSelector"]

# Module logger obtained from the project's logging helper.
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OracleSelector(DraftSelector):
    """Oracle bandit selector that always picks the true best arm.

    Requires ground-truth acceptance rates to be provided at construction
    time and updated via :meth:`set_true_rates`. Used only in regret
    upper-bound experiments — never in production inference.

    Parameters
    ----------
    n_arms : int
        Number of draft-model arms.
    true_rates : list[float]
        Ground-truth acceptance rate for each arm. Must have length ``n_arms``
        with values in ``[0, 1]``. Any iterable of floats is accepted; it is
        copied into an internal list.
    window_size : int
        Rolling window for acceptance statistics.

    Raises
    ------
    ValueError
        If ``len(true_rates) != n_arms`` or any rate is outside ``[0, 1]``.

    Notes
    -----
    The oracle's cumulative reward serves as the upper bound for regret
    calculations in ``tests/unit/test_bandit.py``.

    Examples
    --------
    >>> selector = OracleSelector(n_arms=2, true_rates=[0.6, 0.9])
    >>> selector.select()
    1
    """

    def __init__(
        self,
        n_arms: int,
        true_rates: list[float],
        window_size: int = 500,
    ) -> None:
        # Validate rates before the base initialiser, as before, so a length
        # mismatch is reported ahead of base-class argument errors.
        rates = self._validate_rates(true_rates, n_arms)
        super().__init__(n_arms=n_arms, window_size=window_size)
        self._true_rates: list[float] = rates

    @staticmethod
    def _validate_rates(true_rates: list[float], n_arms: int) -> list[float]:
        """Return a list copy of *true_rates* after checking length and range.

        Raises
        ------
        ValueError
            If the length differs from ``n_arms`` or a value lies outside
            ``[0, 1]`` (NaN is rejected too, since it fails the comparison).
        """
        rates = list(true_rates)
        if len(rates) != n_arms:
            raise ValueError(
                f"len(true_rates) must equal n_arms={n_arms}; "
                f"got {len(rates)}."
            )
        for i, r in enumerate(rates):
            if not (0.0 <= r <= 1.0):
                raise ValueError(f"true_rates[{i}]={r} is outside [0, 1].")
        return rates

    def select(self) -> int:
        """Return the arm index with the highest true acceptance rate.

        Returns
        -------
        int
            Arm index in ``[0, n_arms)``. Ties resolve to the lowest index.

        Notes
        -----
        The oracle has perfect knowledge of ``true_rates`` and always picks
        ``argmax(true_rates)``. It serves as the regret upper bound in
        experiments; it is never used in production inference.

        Examples
        --------
        >>> OracleSelector(n_arms=2, true_rates=[0.4, 0.8]).select()
        1
        """
        with self._lock:
            return int(max(range(self._n_arms), key=self._true_rates.__getitem__))

    def update(self, arm: int, accepted: int) -> None:
        """Record outcome (used for regret tracking only; does not affect selection).

        Parameters
        ----------
        arm : int
            Arm index that was pulled.
        accepted : int
            Number of accepted tokens.

        Raises
        ------
        ValueError
            If ``arm`` is not in ``[0, n_arms)``.

        Notes
        -----
        The oracle's selection policy is independent of observed outcomes;
        it always selects the arm with the highest ``true_rates``. This
        method records statistics only so that cumulative regret can be
        computed from arm pull counts.

        Examples
        --------
        >>> selector.update(1, accepted=1)
        """
        if not (0 <= arm < self._n_arms):
            raise ValueError(f"arm must be in [0, {self._n_arms}); got {arm}.")
        with self._lock:
            self._arms[arm].record(accepted)
            self._t += 1

    def set_true_rates(self, true_rates: list[float]) -> None:
        """Update ground-truth acceptance rates (for non-stationary experiments).

        Parameters
        ----------
        true_rates : list[float]
            New ground-truth rates. Must have the same length as ``n_arms``
            and all values in ``[0, 1]``.

        Raises
        ------
        ValueError
            If length or values are invalid.

        Notes
        -----
        Thread-safe: acquires ``self._lock`` before mutating state.
        Used in chaos tests to simulate a sudden swap of best/worst arm,
        verifying that adaptive bandits (UCB1, Thompson) recover.

        Examples
        --------
        >>> selector.set_true_rates([0.9, 0.4])  # swap best/worst arm
        """
        rates = self._validate_rates(true_rates, self._n_arms)
        with self._lock:
            self._true_rates = rates

    def _state_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable snapshot of the oracle's state."""
        return {
            "type": "oracle",
            "n_arms": self._n_arms,
            "window_size": self._window_size,
            # Copy so callers cannot mutate the live rates through the snapshot.
            "true_rates": list(self._true_rates),
            "t": self._t,
            "arms": [a.to_dict() for a in self._arms],
        }

    @classmethod
    def _from_state_dict(cls, state: dict[str, Any]) -> "OracleSelector":
        """Rebuild an ``OracleSelector`` from :meth:`_state_dict` output.

        Raises
        ------
        KeyError
            If a required key is missing.
        ValueError
            If the rates are invalid or the number of serialised arms does
            not match ``n_arms``.
        """
        obj = cls(
            n_arms=state["n_arms"],
            true_rates=state["true_rates"],
            window_size=state["window_size"],
        )
        arms = [ArmStats.from_dict(d) for d in state["arms"]]
        if len(arms) != obj._n_arms:
            raise ValueError(
                f"state has {len(arms)} arms but n_arms={obj._n_arms}."
            )
        obj._t = state["t"]
        obj._arms = arms
        return obj
|