flashspec 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flashspec/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """FlashSpec — Adaptive speculative-decoding inference engine.
2
+
3
+ Adaptive speculative-decoding inference engine with Triton-optimised
4
+ verification and online bandit draft selection.
5
+
6
+ Public API surface (AGENTS.md §13.2 — do not modify without explicit approval):
7
+
8
+ flashspec.SpeculativeEngine
9
+ flashspec.GenerationResult
10
+ flashspec.FlashSpecConfig
11
+ flashspec.BanditConfig
12
+ flashspec.SamplingConfig
13
+ flashspec.MetricsConfig
14
+ flashspec.register (draft model decorator)
15
+ flashspec.get_drafter
16
+ flashspec.list_drafters
17
+
18
+ References
19
+ ----------
20
+ .. [1] Leviathan et al. (2023), "Fast Inference from Transformers via
21
+ Speculative Decoding", arXiv:2211.17192.
22
+ .. [2] Myet (2025), "FlashSpec: Adaptive Speculative Decoding with Online
23
+ Bandit Draft Selection and Triton-Optimised Verification".
24
+ """
25
+
26
+ from flashspec.engine.drafter import get_drafter, list_drafters, register
27
+ from flashspec.engine.speculative import GenerationResult, SpeculativeEngine
28
+ from flashspec.utils.config import BanditConfig, FlashSpecConfig, MetricsConfig, SamplingConfig
29
+
30
# Public re-exports of the package. The module docstring marks this surface as
# governed by AGENTS.md §13.2 — do not change it without explicit approval.
__all__ = [
    "BanditConfig",
    "FlashSpecConfig",
    "GenerationResult",
    "MetricsConfig",
    "SamplingConfig",
    "SpeculativeEngine",
    "get_drafter",
    "list_drafters",
    "register",
]

# Package metadata; __version__ should match the released wheel version.
__version__ = "0.1.0"
__author__ = "Min Htet Myet (Mattral)"
@@ -0,0 +1,14 @@
1
+ """Online bandit draft selector sub-package."""
2
+
3
+ from flashspec.bandit.base import ArmStats, DraftSelector
4
+ from flashspec.bandit.oracle import OracleSelector
5
+ from flashspec.bandit.thompson import ThompsonSelector
6
+ from flashspec.bandit.ucb import UCB1Selector
7
+
8
# Names re-exported by the bandit sub-package: the shared base types
# (ArmStats, DraftSelector) plus the concrete selectors imported above.
__all__ = [
    "ArmStats",
    "DraftSelector",
    "OracleSelector",
    "ThompsonSelector",
    "UCB1Selector",
]
@@ -0,0 +1,402 @@
1
+ """Abstract base class for online bandit draft selectors.
2
+
3
+ All concrete selectors (UCB1, Thompson, Oracle) inherit from ``DraftSelector``
4
+ and must honour its JSON serialisation and thread-safety contracts.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import threading
11
+ from abc import ABC, abstractmethod
12
+ from collections import deque
13
+ from dataclasses import dataclass, field
14
+ from typing import Any
15
+
16
# Public names of this module; the concrete selectors live in sibling modules.
__all__ = ["DraftSelector", "ArmStats"]
17
+
18
+ # ── Value object for per-arm statistics ───────────────────────────────────────
19
+
20
+
21
+ @dataclass(slots=True, frozen=False)
22
+ class ArmStats:
23
+ """Mutable per-arm statistics used by bandit selectors.
24
+
25
+ Parameters
26
+ ----------
27
+ n_pulls : int
28
+ Total number of times this arm has been selected.
29
+ n_accepted : int
30
+ Total number of accepted tokens attributed to this arm.
31
+ window_accepts : deque[int]
32
+ Rolling window of per-round accept counts (1 or 0) for windowed stats.
33
+ window_size : int
34
+ Maximum size of the rolling window. 0 disables windowing.
35
+ """
36
+
37
+ n_pulls: int = 0
38
+ n_accepted: int = 0
39
+ window_accepts: deque[int] = field(default_factory=deque)
40
+ window_size: int = 500
41
+
42
+ def record(self, accepted: int) -> None:
43
+ """Record the outcome of one round for this arm.
44
+
45
+ Parameters
46
+ ----------
47
+ accepted : int
48
+ Number of tokens accepted in this round (typically 0 or 1).
49
+
50
+ Returns
51
+ -------
52
+ None
53
+
54
+ Notes
55
+ -----
56
+ When ``window_size > 0`` the oldest entry is evicted once the window
57
+ is full, so ``mean_accept_rate`` reflects only the most recent
58
+ ``window_size`` rounds.
59
+
60
+ Examples
61
+ --------
62
+ >>> stats = ArmStats(window_size=100)
63
+ >>> stats.record(accepted=1)
64
+ >>> stats.n_pulls
65
+ 1
66
+ """
67
+ self.n_pulls += 1
68
+ self.n_accepted += accepted
69
+ if self.window_size > 0:
70
+ self.window_accepts.append(accepted)
71
+ if len(self.window_accepts) > self.window_size:
72
+ self.window_accepts.popleft()
73
+
74
+ @property
75
+ def mean_accept_rate(self) -> float:
76
+ """Mean acceptance rate, optionally windowed.
77
+
78
+ Returns
79
+ -------
80
+ float
81
+ Windowed mean if ``window_size > 0`` and there are observations,
82
+ else global mean, else 0.0.
83
+
84
+ Notes
85
+ -----
86
+ When windowing is enabled (``window_size > 0``) the rate reflects
87
+ only the last ``window_size`` rounds, allowing the bandit to track
88
+ non-stationary acceptance distributions.
89
+
90
+ Examples
91
+ --------
92
+ >>> stats = ArmStats(window_size=0)
93
+ >>> stats.record(1); stats.record(0)
94
+ >>> stats.mean_accept_rate
95
+ 0.5
96
+ """
97
+ if self.window_size > 0 and self.window_accepts:
98
+ return sum(self.window_accepts) / len(self.window_accepts)
99
+ if self.n_pulls > 0:
100
+ return self.n_accepted / self.n_pulls
101
+ return 0.0
102
+
103
+ def to_dict(self) -> dict[str, Any]:
104
+ """Serialise to a JSON-compatible dict.
105
+
106
+ Returns
107
+ -------
108
+ dict[str, Any]
109
+ Dictionary with keys ``n_pulls``, ``n_accepted``,
110
+ ``window_accepts``, and ``window_size``.
111
+
112
+ Notes
113
+ -----
114
+ The returned dict can be passed directly to :meth:`from_dict` to
115
+ reconstruct an identical ``ArmStats`` instance.
116
+
117
+ Examples
118
+ --------
119
+ >>> stats = ArmStats(n_pulls=5, n_accepted=3, window_size=10)
120
+ >>> d = stats.to_dict()
121
+ >>> d["n_pulls"]
122
+ 5
123
+ """
124
+ return {
125
+ "n_pulls": self.n_pulls,
126
+ "n_accepted": self.n_accepted,
127
+ "window_accepts": list(self.window_accepts),
128
+ "window_size": self.window_size,
129
+ }
130
+
131
+ @classmethod
132
+ def from_dict(cls, d: dict[str, Any]) -> "ArmStats":
133
+ """Deserialise from a dict produced by :meth:`to_dict`.
134
+
135
+ Parameters
136
+ ----------
137
+ d : dict[str, Any]
138
+ Dictionary as returned by :meth:`to_dict`.
139
+
140
+ Returns
141
+ -------
142
+ ArmStats
143
+ Reconstructed instance with identical statistics.
144
+
145
+ Notes
146
+ -----
147
+ The ``window_accepts`` deque is reconstructed with the original
148
+ ``window_size`` as its ``maxlen``.
149
+
150
+ Examples
151
+ --------
152
+ >>> stats = ArmStats(window_size=50)
153
+ >>> stats.record(1)
154
+ >>> restored = ArmStats.from_dict(stats.to_dict())
155
+ >>> restored.n_pulls == stats.n_pulls
156
+ True
157
+ """
158
+ obj = cls(
159
+ n_pulls=d["n_pulls"],
160
+ n_accepted=d["n_accepted"],
161
+ window_size=d["window_size"],
162
+ )
163
+ obj.window_accepts = deque(
164
+ d["window_accepts"], maxlen=d["window_size"] or None
165
+ )
166
+ return obj
167
+
168
+
169
+ # ── Abstract selector ─────────────────────────────────────────────────────────
170
+
171
+
172
class DraftSelector(ABC):
    """Abstract base class for online bandit draft-model selectors.

    Subclasses implement :meth:`select`, :meth:`update`, :meth:`_state_dict`
    and :meth:`_from_state_dict`. Each instance owns a ``threading.Lock``
    (``self._lock``); :meth:`reset` and :meth:`to_json` acquire it, and
    concrete subclasses are expected to acquire it around their own
    read-modify-write operations so that multiple generation workers can
    share a single selector.

    Parameters
    ----------
    n_arms : int
        Number of draft-model arms.
    window_size : int
        Rolling window size for acceptance statistics (0 = disabled).

    Raises
    ------
    ValueError
        If ``n_arms`` < 1 or ``window_size`` < 0.

    Notes
    -----
    The selector maintains one :class:`ArmStats` object per arm.
    The internal round counter ``t`` counts total calls to :meth:`update`.

    Examples
    --------
    >>> selector = UCB1Selector(n_arms=3, window_size=200)
    >>> arm = selector.select()
    >>> selector.update(arm, accepted=1)
    """

    def __init__(self, n_arms: int, window_size: int = 500) -> None:
        if n_arms < 1:
            raise ValueError(f"n_arms must be >= 1; got {n_arms}.")
        if window_size < 0:
            raise ValueError(f"window_size must be >= 0; got {window_size}.")
        self._n_arms = n_arms
        self._window_size = window_size
        self._arms: list[ArmStats] = [
            ArmStats(window_size=window_size) for _ in range(n_arms)
        ]
        self._t: int = 0
        # Non-reentrant: code already holding it must not try to re-acquire.
        self._lock = threading.Lock()

    # ── Public interface ───────────────────────────────────────────────────

    @property
    def n_arms(self) -> int:
        """Number of arms.

        Returns
        -------
        int
            Count of available draft-model arms.

        Notes
        -----
        Fixed at construction time; cannot be changed after initialisation.

        Examples
        --------
        >>> selector = UCB1Selector(n_arms=3)
        >>> selector.n_arms
        3
        """
        return self._n_arms

    @property
    def t(self) -> int:
        """Total rounds elapsed (equal to the number of :meth:`update` calls).

        Returns
        -------
        int
            Non-negative integer round counter.

        Notes
        -----
        Resets to 0 after :meth:`reset` is called.

        Examples
        --------
        >>> selector = UCB1Selector(n_arms=2)
        >>> selector.update(0, accepted=1)
        >>> selector.t
        1
        """
        return self._t

    @abstractmethod
    def select(self) -> int:
        """Select an arm index to pull.

        Returns
        -------
        int
            Index in ``[0, n_arms)``.

        Notes
        -----
        Implementations must be thread-safe (acquire ``self._lock`` around
        any read-modify-write on shared state).

        Examples
        --------
        >>> arm = selector.select()
        >>> assert 0 <= arm < selector.n_arms
        """

    @abstractmethod
    def update(self, arm: int, accepted: int) -> None:
        """Record the outcome of pulling an arm.

        Parameters
        ----------
        arm : int
            Index of the arm that was pulled.
        accepted : int
            Number of tokens accepted in this round.

        Raises
        ------
        ValueError
            If ``arm`` is not in ``[0, n_arms)``.

        Notes
        -----
        Implementations are expected to increment the internal round counter
        ``t`` and delegate to ``self._arms[arm].record(accepted)``.

        Examples
        --------
        >>> selector.update(0, accepted=1)
        """

    def reset(self) -> None:
        """Reset all arm statistics and the round counter to zero.

        Returns
        -------
        None

        Notes
        -----
        Intended for per-context-window resets when the prompt distribution
        shifts and accumulated statistics are no longer representative.
        Thread-safe: acquires ``self._lock`` before mutating state.

        Examples
        --------
        >>> selector.reset()
        >>> selector.t
        0
        """
        with self._lock:
            self._arms = [
                ArmStats(window_size=self._window_size)
                for _ in range(self._n_arms)
            ]
            self._t = 0

    def to_json(self) -> str:
        """Serialise bandit state to a JSON string.

        Returns
        -------
        str
            Compact JSON-encoded bandit state suitable for checkpointing.

        Notes
        -----
        Thread-safe: acquires ``self._lock`` before reading state.
        The returned string can be passed to :meth:`from_json` on the *same*
        concrete subclass to reconstruct an equivalent instance.

        Examples
        --------
        >>> state_json = selector.to_json()
        >>> selector2 = UCB1Selector.from_json(state_json)
        """
        with self._lock:
            return json.dumps(self._state_dict(), separators=(",", ":"))

    @classmethod
    def from_json(cls, json_str: str) -> "DraftSelector":
        """Restore bandit state from a JSON string produced by :meth:`to_json`.

        Parameters
        ----------
        json_str : str
            JSON string previously produced by :meth:`to_json`.

        Returns
        -------
        DraftSelector
            Restored selector instance with identical state.

        Raises
        ------
        ValueError
            If ``json_str`` is not valid JSON, does not decode to a JSON
            object, or is missing a field required by the subclass.

        Notes
        -----
        Delegates to ``cls._from_state_dict``: the class this method is
        called on decides how the state is interpreted. Any ``"type"`` entry
        in the state (e.g. the one ``OracleSelector`` writes) is *not* used
        for dispatch, so call this on the concrete class that produced the
        JSON.

        Examples
        --------
        >>> json_str = selector.to_json()
        >>> restored = UCB1Selector.from_json(json_str)
        >>> restored.t == selector.t
        True
        """
        try:
            state = json.loads(json_str)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON for bandit state: {exc}") from exc
        if not isinstance(state, dict):
            # e.g. "[1, 2]" — indexing it by field name would raise TypeError.
            raise ValueError(
                "Bandit state JSON must decode to an object; "
                f"got {type(state).__name__}."
            )
        try:
            return cls._from_state_dict(state)
        except KeyError as exc:
            # Honour the documented contract: missing fields -> ValueError.
            raise ValueError(
                f"Bandit state is missing required field {exc}."
            ) from exc

    # ── Subclass hooks ─────────────────────────────────────────────────────

    @abstractmethod
    def _state_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable dict of all state.

        Called by :meth:`to_json` while ``self._lock`` is held; because
        ``threading.Lock`` is not re-entrant, implementations must not try
        to acquire the lock themselves.
        """

    @classmethod
    @abstractmethod
    def _from_state_dict(cls, state: dict[str, Any]) -> "DraftSelector":
        """Restore an instance from a state dict.

        Called by :meth:`from_json`; a ``KeyError`` raised here for a missing
        field is reported to the caller as ``ValueError``.
        """
@@ -0,0 +1,181 @@
1
+ """Oracle bandit selector — upper-bound baseline for regret experiments.
2
+
3
+ The Oracle always picks the arm with the highest *true* acceptance rate,
4
+ which must be supplied externally. It is used only to compute the regret
5
+ upper bound in experiments; it is never used in production inference.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any
11
+
12
+ from flashspec.bandit.base import ArmStats, DraftSelector
13
+ from flashspec.utils.logging import get_logger
14
+
15
__all__ = ["OracleSelector"]

# Module logger. NOTE(review): nothing in this module currently logs through
# it — either add log calls or drop it.
logger = get_logger(__name__)
18
+
19
+
20
class OracleSelector(DraftSelector):
    """Oracle bandit selector that always picks the true best arm.

    Requires ground-truth acceptance rates to be provided at construction
    time and updated via :meth:`set_true_rates`. Used only in regret
    upper-bound experiments — never in production inference.

    Parameters
    ----------
    n_arms : int
        Number of draft-model arms.
    true_rates : list[float]
        Ground-truth acceptance rate for each arm. Must have length ``n_arms``
        with values in ``[0, 1]``.
    window_size : int
        Rolling window for acceptance statistics.

    Raises
    ------
    ValueError
        If ``len(true_rates) != n_arms`` or any rate is outside ``[0, 1]``.

    Notes
    -----
    The oracle's cumulative reward serves as the upper bound for regret
    calculations in ``tests/unit/test_bandit.py``.

    Examples
    --------
    >>> selector = OracleSelector(n_arms=2, true_rates=[0.6, 0.9])
    >>> selector.select()
    1
    """

    def __init__(
        self,
        n_arms: int,
        true_rates: list[float],
        window_size: int = 500,
    ) -> None:
        # Validated before the base initialiser, so a bad ``true_rates`` is
        # reported ahead of any n_arms/window_size error (original ordering).
        self._validate_true_rates(true_rates, n_arms)
        super().__init__(n_arms=n_arms, window_size=window_size)
        self._true_rates: list[float] = list(true_rates)

    @staticmethod
    def _validate_true_rates(true_rates: list[float], n_arms: int) -> None:
        """Raise ``ValueError`` unless *true_rates* has ``n_arms`` values in [0, 1]."""
        if len(true_rates) != n_arms:
            raise ValueError(
                f"len(true_rates) must equal n_arms={n_arms}; "
                f"got {len(true_rates)}."
            )
        for i, r in enumerate(true_rates):
            # ``not (lo <= r <= hi)`` also rejects NaN (all comparisons False).
            if not (0.0 <= r <= 1.0):
                raise ValueError(f"true_rates[{i}]={r} is outside [0, 1].")

    def select(self) -> int:
        """Return the arm index with the highest true acceptance rate.

        Returns
        -------
        int
            Arm index in ``[0, n_arms)``. Ties resolve to the lowest index.

        Notes
        -----
        The oracle has perfect knowledge of ``true_rates`` and always picks
        ``argmax(true_rates)``. It serves as the regret upper bound in
        experiments; it is never used in production inference.

        Examples
        --------
        >>> OracleSelector(n_arms=2, true_rates=[0.4, 0.8]).select()
        1
        """
        with self._lock:
            return int(max(range(self._n_arms), key=lambda k: self._true_rates[k]))

    def update(self, arm: int, accepted: int) -> None:
        """Record outcome (used for regret tracking only; does not affect selection).

        Parameters
        ----------
        arm : int
            Arm index that was pulled.
        accepted : int
            Number of accepted tokens.

        Raises
        ------
        ValueError
            If ``arm`` is not in ``[0, n_arms)``.

        Notes
        -----
        The oracle's selection policy is independent of observed outcomes;
        it always selects the arm with the highest ``true_rates``. This
        method records statistics only so that cumulative regret can be
        computed from arm pull counts.

        Examples
        --------
        >>> selector.update(1, accepted=1)
        """
        # n_arms is fixed after construction, so this check needs no lock.
        if not (0 <= arm < self._n_arms):
            raise ValueError(f"arm must be in [0, {self._n_arms}); got {arm}.")
        with self._lock:
            self._arms[arm].record(accepted)
            self._t += 1

    def set_true_rates(self, true_rates: list[float]) -> None:
        """Update ground-truth acceptance rates (for non-stationary experiments).

        Parameters
        ----------
        true_rates : list[float]
            New ground-truth rates. Must have the same length as ``n_arms``
            and all values in ``[0, 1]``.

        Raises
        ------
        ValueError
            If length or values are invalid.

        Notes
        -----
        Thread-safe: acquires ``self._lock`` before mutating state.
        Used in chaos tests to simulate a sudden swap of best/worst arm,
        verifying that adaptive bandits (UCB1, Thompson) recover.

        Examples
        --------
        >>> selector.set_true_rates([0.9, 0.4])  # swap best/worst arm
        """
        self._validate_true_rates(true_rates, self._n_arms)
        with self._lock:
            self._true_rates = list(true_rates)

    def _state_dict(self) -> dict[str, Any]:
        """Return the JSON-serialisable oracle state (called under ``self._lock``)."""
        return {
            "type": "oracle",
            "n_arms": self._n_arms,
            "window_size": self._window_size,
            # Copy so the returned dict never aliases internal mutable state.
            "true_rates": list(self._true_rates),
            "t": self._t,
            "arms": [a.to_dict() for a in self._arms],
        }

    @classmethod
    def _from_state_dict(cls, state: dict[str, Any]) -> "OracleSelector":
        """Rebuild an ``OracleSelector`` from :meth:`_state_dict` output.

        Raises
        ------
        KeyError
            If a required field is missing.
        ValueError
            If ``true_rates`` is invalid or the number of serialised arms
            differs from ``n_arms``.
        """
        obj = cls(
            n_arms=state["n_arms"],
            true_rates=state["true_rates"],
            window_size=state["window_size"],
        )
        arms = [ArmStats.from_dict(d) for d in state["arms"]]
        if len(arms) != obj._n_arms:
            # A mismatch would otherwise surface later as IndexError in update.
            raise ValueError(
                f"state['arms'] has {len(arms)} entries; "
                f"expected n_arms={obj._n_arms}."
            )
        obj._t = state["t"]
        obj._arms = arms
        return obj