qortex-learning 0.1.0 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qortex_learning-0.1.0/.gitignore +58 -0
- qortex_learning-0.1.0/PKG-INFO +38 -0
- qortex_learning-0.1.0/README.md +18 -0
- qortex_learning-0.1.0/pyproject.toml +41 -0
- qortex_learning-0.1.0/src/qortex/learning/__init__.py +51 -0
- qortex_learning-0.1.0/src/qortex/learning/learner.py +364 -0
- qortex_learning-0.1.0/src/qortex/learning/pg_store.py +211 -0
- qortex_learning-0.1.0/src/qortex/learning/reward.py +34 -0
- qortex_learning-0.1.0/src/qortex/learning/store.py +326 -0
- qortex_learning-0.1.0/src/qortex/learning/strategy.py +162 -0
- qortex_learning-0.1.0/src/qortex/learning/types.py +112 -0
qortex_learning-0.1.0/.gitignore
@@ -0,0 +1,58 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+.venv/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+.DS_Store
+
+# Testing
+.pytest_cache/
+.hypothesis/
+.coverage
+htmlcov/
+.tox/
+.nox/
+
+# Type checking
+.mypy_cache/
+
+# Local data
+*.db
+*.sqlite
+data/
+
+# MkDocs
+site/
+
+# Secrets
+.env
+.env.local
+*.pem
+*.key
qortex_learning-0.1.0/PKG-INFO
@@ -0,0 +1,38 @@
+Metadata-Version: 2.4
+Name: qortex-learning
+Version: 0.1.0
+Summary: Bandit-based adaptive learning for qortex: Thompson Sampling, reward models, persistent state.
+Author: Peleke Sengstacke
+License-Expression: MIT
+Keywords: adaptive-learning,bandit,knowledge-graph,thompson-sampling
+Classifier: Development Status :: 2 - Pre-Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Requires-Python: >=3.11
+Requires-Dist: aiosqlite>=0.20
+Requires-Dist: qortex-observe
+Provides-Extra: dev
+Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
+Requires-Dist: pytest>=8.0; extra == 'dev'
+Description-Content-Type: text/markdown
+
+# qortex-learning
+
+Bandit-based adaptive learning for qortex. Thompson Sampling with Beta-Bernoulli posteriors, persistent state via SQLite, and pluggable reward models.
+
+## Usage
+
+```python
+from qortex.learning import Learner, LearnerConfig, Arm, ArmOutcome
+
+learner = await Learner.create(LearnerConfig(name="prompts"))
+
+candidates = [Arm(id="v1", token_cost=10), Arm(id="v2", token_cost=15)]
+result = await learner.select(candidates, context={"task": "type-check"}, k=1)
+
+await learner.observe(ArmOutcome(arm_id="v2", outcome="accepted", reward=1.0))
+```
+
+Part of the [qortex](https://github.com/Peleke/qortex) workspace.
qortex_learning-0.1.0/README.md
@@ -0,0 +1,18 @@
+# qortex-learning
+
+Bandit-based adaptive learning for qortex. Thompson Sampling with Beta-Bernoulli posteriors, persistent state via SQLite, and pluggable reward models.
+
+## Usage
+
+```python
+from qortex.learning import Learner, LearnerConfig, Arm, ArmOutcome
+
+learner = await Learner.create(LearnerConfig(name="prompts"))
+
+candidates = [Arm(id="v1", token_cost=10), Arm(id="v2", token_cost=15)]
+result = await learner.select(candidates, context={"task": "type-check"}, k=1)
+
+await learner.observe(ArmOutcome(arm_id="v2", outcome="accepted", reward=1.0))
+```
+
+Part of the [qortex](https://github.com/Peleke/qortex) workspace.
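The README example passes both an outcome label and an explicit reward. As `learner.py` below shows, `observe()` falls back to the reward model whenever `reward` is falsy, so outcome-only observations also work under the default `TernaryReward` (accepted=1, partial=0.5, rejected=0). A minimal sketch in the same style as the README snippet, assuming `ArmOutcome.reward` is optional (its definition in `types.py` is not shown in this section):

```python
from qortex.learning import ArmOutcome

# Outcome-only observations: the reward is derived by the default
# TernaryReward model (accepted=1.0, partial=0.5, rejected=0.0).
await learner.observe(ArmOutcome(arm_id="v1", outcome="partial"))
await learner.observe(ArmOutcome(arm_id="v2", outcome="rejected"))
```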
qortex_learning-0.1.0/pyproject.toml
@@ -0,0 +1,41 @@
+[project]
+name = "qortex-learning"
+version = "0.1.0"
+description = "Bandit-based adaptive learning for qortex: Thompson Sampling, reward models, persistent state."
+readme = "README.md"
+requires-python = ">=3.11"
+license = "MIT"
+authors = [
+    { name = "Peleke Sengstacke" }
+]
+keywords = [
+    "bandit",
+    "thompson-sampling",
+    "adaptive-learning",
+    "knowledge-graph",
+]
+classifiers = [
+    "Development Status :: 2 - Pre-Alpha",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+
+dependencies = [
+    "aiosqlite>=0.20",
+    "qortex-observe",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=8.0",
+    "pytest-asyncio>=0.23",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/qortex"]
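The `dev` extra pulls in pytest and pytest-asyncio, which the fully async API effectively requires for testing. A hypothetical test sketch under those dev dependencies (not shipped in the package; the `state_dir` keyword is an assumption, inferred from `config.state_dir` in `learner.py` below):

```python
import pytest

from qortex.learning import ArmOutcome, Learner, LearnerConfig


@pytest.mark.asyncio
async def test_observe_updates_posterior(tmp_path):
    # state_dir is assumed to be a LearnerConfig field, per
    # SqliteLearningStore(config.name, config.state_dir) in learner.py.
    learner = await Learner.create(LearnerConfig(name="t", state_dir=tmp_path))
    state = await learner.observe(
        ArmOutcome(arm_id="a", outcome="accepted", reward=1.0)
    )
    assert state.pulls >= 1
```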
qortex_learning-0.1.0/src/qortex/learning/__init__.py
@@ -0,0 +1,51 @@
+"""qortex learning module: bandit-based learning for prompt optimization.
+
+Public API:
+    Learner — Main class: select(), observe(), metrics(), posteriors()
+    LearnerConfig — Configuration for a Learner instance
+    LearningStore — Protocol for state persistence backends
+    JsonLearningStore — JSON file backend
+    SqliteLearningStore — SQLite backend (default)
+    ThompsonSampling — Beta-Bernoulli Thompson Sampling strategy
+    BinaryReward — Binary reward model (accepted=1, else=0)
+    TernaryReward — Ternary reward model (accepted=1, partial=0.5, rejected=0)
+"""
+
+from qortex.learning.learner import Learner
+from qortex.learning.reward import BinaryReward, RewardModel, TernaryReward
+from qortex.learning.store import JsonLearningStore, LearningStore, SqliteLearningStore
+
+try:
+    from qortex.learning.pg_store import PostgresLearningStore
+except ImportError:
+    PostgresLearningStore = None  # type: ignore[assignment,misc]
+from qortex.learning.strategy import LearningStrategy, ThompsonSampling
+from qortex.learning.types import (
+    Arm,
+    ArmOutcome,
+    ArmState,
+    LearnerConfig,
+    RunTrace,
+    SelectionResult,
+    context_hash,
+)
+
+__all__ = [
+    "Learner",
+    "LearnerConfig",
+    "LearningStore",
+    "JsonLearningStore",
+    "SqliteLearningStore",
+    "PostgresLearningStore",
+    "LearningStrategy",
+    "ThompsonSampling",
+    "RewardModel",
+    "BinaryReward",
+    "TernaryReward",
+    "Arm",
+    "ArmOutcome",
+    "ArmState",
+    "SelectionResult",
+    "RunTrace",
+    "context_hash",
+]
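The `try/except ImportError` above degrades gracefully: `PostgresLearningStore` is still exported, but as `None`, when the optional Postgres backend cannot be imported. Callers can branch on that sentinel; a sketch of the fallback wiring (illustrative, not from the package):

```python
from qortex.learning import Learner, LearnerConfig, PostgresLearningStore

# PostgresLearningStore is None when pg_store's driver isn't installed.
# Fall back to the default backend, which Learner wires up itself
# (SqliteLearningStore, per learner.py below). The Postgres store's
# constructor arguments live in pg_store.py, not shown in this section.
if PostgresLearningStore is None:
    learner = await Learner.create(LearnerConfig(name="prompts"))
```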
qortex_learning-0.1.0/src/qortex/learning/learner.py
@@ -0,0 +1,364 @@
+"""Learner: the main class for bandit-based learning.
+
+Composes strategy + reward model + state store. Exposes select(),
+observe(), batch_observe(), top_arms(), decay_arm(), metrics(),
+and posteriors(). Emits observability events.
+
+All I/O methods are async. Use ``await Learner.create(config)`` to
+construct (seed boost requires store I/O).
+"""
+
+from __future__ import annotations
+
+import uuid
+from datetime import UTC, datetime
+from typing import Any
+
+from qortex.learning.reward import RewardModel, TernaryReward
+from qortex.learning.store import LearningStore, SqliteLearningStore
+from qortex.learning.strategy import LearningStrategy, ThompsonSampling
+from qortex.learning.types import (
+    Arm,
+    ArmOutcome,
+    ArmState,
+    LearnerConfig,
+    RunTrace,
+    SelectionResult,
+    context_hash,
+)
+from qortex.observe import emit
+from qortex.observe.events import (
+    LearningObservationRecorded,
+    LearningPosteriorUpdated,
+    LearningSelectionMade,
+)
+from qortex.observe.tracing import traced
+
+
+class Learner:
+    """A bandit learner that selects arms and updates posteriors.
+
+    Usage::
+
+        learner = await Learner.create(LearnerConfig(name="prompts"))
+        result = await learner.select(candidates, context={"task": "type-errors"}, k=3)
+        # ... use selected arms ...
+        await learner.observe(ArmOutcome(arm_id="prompt:v2", reward=1.0, outcome="accepted"))
+    """
+
+    def __init__(
+        self,
+        config: LearnerConfig,
+        strategy: LearningStrategy | None = None,
+        reward_model: RewardModel | None = None,
+        store: LearningStore | None = None,
+    ) -> None:
+        self.config = config
+        self.strategy = strategy or ThompsonSampling()
+        self.reward_model = reward_model or TernaryReward()
+        self.store = store or SqliteLearningStore(config.name, config.state_dir)
+        self._seeded = False
+        self._sessions: dict[str, RunTrace] = {}
+
+    @classmethod
+    async def create(
+        cls,
+        config: LearnerConfig,
+        strategy: LearningStrategy | None = None,
+        reward_model: RewardModel | None = None,
+        store: LearningStore | None = None,
+    ) -> Learner:
+        """Create a Learner and apply seed boosts (requires async store I/O)."""
+        learner = cls(config, strategy, reward_model, store)
+        await learner._apply_seed_boosts()
+        return learner
+
+    async def _apply_seed_boosts(self) -> None:
+        """Apply seed arm boosts. Idempotent — only boosts arms with zero pulls."""
+        if self._seeded:
+            return
+        self._seeded = True
+        for arm_id in self.config.seed_arms:
+            state = await self.store.get(arm_id)
+            if state.pulls == 0:
+                state = ArmState(
+                    alpha=self.config.seed_boost,
+                    beta=1.0,
+                    pulls=0,
+                    total_reward=0.0,
+                    last_updated=datetime.now(UTC).isoformat(),
+                )
+                await self.store.put(arm_id, state)
+        if self.config.seed_arms:
+            await self.store.save()
+
+    @traced("learning.select")
+    async def select(
+        self,
+        candidates: list[Arm],
+        context: dict | None = None,
+        k: int = 1,
+        token_budget: int = 0,
+    ) -> SelectionResult:
+        """Select k arms from candidates."""
+        ctx = context or {}
+        states = {arm.id: await self.store.get(arm.id, ctx) for arm in candidates}
+
+        result = self.strategy.select(
+            candidates=candidates,
+            states=states,
+            k=k,
+            config=self.config,
+            token_budget=token_budget,
+        )
+
+        emit(
+            LearningSelectionMade(
+                learner=self.config.name,
+                selected_count=len(result.selected),
+                excluded_count=len(result.excluded),
+                is_baseline=result.is_baseline,
+                token_budget=result.token_budget,
+                used_tokens=result.used_tokens,
+            )
+        )
+
+        return result
+
+    @traced("learning.observe")
+    async def observe(
+        self,
+        outcome: ArmOutcome,
+        context: dict | None = None,
+    ) -> ArmState:
+        """Record an observation and update posterior."""
+        ctx = context or outcome.context
+        reward = outcome.reward
+        if outcome.outcome and not outcome.reward:
+            reward = self.reward_model.compute(outcome.outcome)
+
+        state = await self.store.get(outcome.arm_id, ctx)
+        now = datetime.now(UTC).isoformat()
+
+        new_state = self.strategy.update(outcome.arm_id, reward, state)
+        new_state = ArmState(
+            alpha=new_state.alpha,
+            beta=new_state.beta,
+            pulls=new_state.pulls,
+            total_reward=new_state.total_reward,
+            last_updated=now,
+        )
+
+        await self.store.put(outcome.arm_id, new_state, ctx)
+        await self.store.save()
+
+        ctx_hash = context_hash(ctx)
+
+        emit(
+            LearningObservationRecorded(
+                learner=self.config.name,
+                arm_id=outcome.arm_id,
+                reward=reward,
+                outcome=outcome.outcome,
+                context_hash=ctx_hash,
+            )
+        )
+
+        emit(
+            LearningPosteriorUpdated(
+                learner=self.config.name,
+                arm_id=outcome.arm_id,
+                alpha=new_state.alpha,
+                beta=new_state.beta,
+                pulls=new_state.pulls,
+                mean=new_state.mean,
+            )
+        )
+
+        return new_state
+
+    @traced("learning.apply_credit_deltas")
+    async def apply_credit_deltas(
+        self,
+        deltas: dict[str, dict[str, float]],
+        context: dict | None = None,
+    ) -> dict[str, ArmState]:
+        """Apply causal credit deltas directly to arm posteriors.
+
+        Unlike observe() which computes alpha/beta from a binary reward,
+        this applies arbitrary deltas from CreditAssigner output.
+        """
+        ctx = context or {}
+        now = datetime.now(UTC).isoformat()
+        results: dict[str, ArmState] = {}
+
+        for arm_id, delta in deltas.items():
+            state = await self.store.get(arm_id, ctx)
+            new_state = ArmState(
+                alpha=max(state.alpha + delta.get("alpha_delta", 0.0), 0.01),
+                beta=max(state.beta + delta.get("beta_delta", 0.0), 0.01),
+                pulls=state.pulls + 1,
+                total_reward=state.total_reward + delta.get("alpha_delta", 0.0),
+                last_updated=now,
+            )
+            await self.store.put(arm_id, new_state, ctx)
+            results[arm_id] = new_state
+
+            emit(
+                LearningPosteriorUpdated(
+                    learner=self.config.name,
+                    arm_id=arm_id,
+                    alpha=new_state.alpha,
+                    beta=new_state.beta,
+                    pulls=new_state.pulls,
+                    mean=new_state.mean,
+                )
+            )
+
+        await self.store.save()
+        return results
+
+    async def reset(
+        self,
+        arm_ids: list[str] | None = None,
+        context: dict | None = None,
+    ) -> int:
+        """Delete arm states and return count of entries removed.
+
+        Useful for resetting poisoned posteriors or clearing stale data.
+        """
+        count = await self.store.delete(arm_ids=arm_ids, context=context)
+        await self.store.save()
+        return count
+
+    async def batch_observe(
+        self,
+        outcomes: list[ArmOutcome],
+        context: dict | None = None,
+    ) -> dict[str, ArmState]:
+        """Record multiple observations in a single call.
+
+        Delegates to observe() for each outcome so events are emitted
+        per-arm. Saves once at the end (observe already saves per call,
+        but this keeps the contract simple for callers).
+        """
+        results: dict[str, ArmState] = {}
+        for outcome in outcomes:
+            results[outcome.arm_id] = await self.observe(outcome, context)
+        return results
+
+    async def top_arms(
+        self,
+        context: dict | None = None,
+        k: int = 10,
+    ) -> list[tuple[str, ArmState]]:
+        """Return the top-k arms by posterior mean, descending.
+
+        Derives from posteriors(). Returns (arm_id, ArmState) tuples
+        so callers get both the ID and the full state.
+        """
+        all_states = await self.store.get_all(context)
+        sorted_arms = sorted(
+            all_states.items(),
+            key=lambda pair: pair[1].mean,
+            reverse=True,
+        )
+        return sorted_arms[:k]
+
+    async def decay_arm(
+        self,
+        arm_id: str,
+        decay_factor: float = 0.9,
+        context: dict | None = None,
+    ) -> ArmState:
+        """Shrink an arm's learned signal toward the prior.
+
+        Multiplies alpha and beta by decay_factor, weakening confidence
+        while preserving the mean ratio. Useful when seed data changes
+        and old signal should fade.
+
+        Floors alpha/beta at 0.01 to avoid degenerate posteriors.
+        """
+        state = await self.store.get(arm_id, context)
+        now = datetime.now(UTC).isoformat()
+        new_state = ArmState(
+            alpha=max(state.alpha * decay_factor, 0.01),
+            beta=max(state.beta * decay_factor, 0.01),
+            pulls=state.pulls,
+            total_reward=state.total_reward * decay_factor,
+            last_updated=now,
+        )
+        await self.store.put(arm_id, new_state, context)
+        await self.store.save()
+        return new_state
+
+    async def posteriors(
+        self,
+        context: dict | None = None,
+        arm_ids: list[str] | None = None,
+    ) -> dict[str, dict[str, Any]]:
+        """Get current posteriors for arms."""
+        all_states = await self.store.get_all(context)
+
+        if arm_ids is not None:
+            all_states = {k: v for k, v in all_states.items() if k in set(arm_ids)}
+
+        return {
+            arm_id: {
+                **state.to_dict(),
+                "mean": state.mean,
+            }
+            for arm_id, state in all_states.items()
+        }
+
+    async def metrics(self, window: int | None = None) -> dict[str, Any]:
+        """Compute learning metrics across all contexts."""
+        total_pulls = 0
+        total_reward = 0.0
+        arm_count = 0
+
+        all_states = await self.store.get_all_states()
+        for states in all_states.values():
+            for state in states.values():
+                total_pulls += state.pulls
+                total_reward += state.total_reward
+                arm_count += 1
+
+        accuracy = total_reward / max(total_pulls, 1)
+        explore_ratio = self.config.baseline_rate
+
+        return {
+            "learner": self.config.name,
+            "total_pulls": total_pulls,
+            "total_reward": total_reward,
+            "accuracy": round(accuracy, 4),
+            "arm_count": arm_count,
+            "explore_ratio": explore_ratio,
+        }
+
+    def session_start(self, session_name: str) -> str:
+        """Start a named learning session for tracking."""
+        session_id = str(uuid.uuid4())
+        self._sessions[session_id] = RunTrace(
+            session_id=session_id,
+            learner=self.config.name,
+            selected_arms=[],
+            started_at=datetime.now(UTC).isoformat(),
+        )
+        return session_id
+
+    def session_end(self, session_id: str) -> dict[str, Any]:
+        """End a session and return summary."""
+        trace = self._sessions.pop(session_id, None)
+        if trace is None:
+            return {"error": f"Session {session_id} not found"}
+
+        trace.ended_at = datetime.now(UTC).isoformat()
+        return {
+            "session_id": session_id,
+            "learner": trace.learner,
+            "selected_arms": trace.selected_arms,
+            "outcomes": trace.outcomes,
+            "started_at": trace.started_at,
+            "ended_at": trace.ended_at,
+        }
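Two of the less obvious methods above, tied together: apply_credit_deltas() reads "alpha_delta"/"beta_delta" keys from each per-arm dict, and decay_arm() shrinks both Beta parameters by a factor, keeping the posterior mean alpha/(alpha+beta) but weakening its confidence. A usage sketch (the arm IDs and delta values are illustrative, not from the package):

```python
# Credit-assignment output: per-arm Beta-parameter deltas, keyed as
# apply_credit_deltas() expects ("alpha_delta" / "beta_delta").
deltas = {
    "prompt:v1": {"alpha_delta": 0.8, "beta_delta": 0.2},
    "prompt:v2": {"alpha_delta": 0.1, "beta_delta": 0.9},
}
states = await learner.apply_credit_deltas(deltas, context={"task": "type-check"})

# Fade stale signal: alpha and beta both shrink by 0.9, so the
# posterior mean is preserved while its confidence drops.
await learner.decay_arm("prompt:v1", decay_factor=0.9)

# Inspect the strongest arms by posterior mean.
for arm_id, state in await learner.top_arms(k=3):
    print(arm_id, round(state.mean, 3), state.pulls)
```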