qortex-learning 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,58 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ .venv/
25
+ venv/
26
+ ENV/
27
+
28
+ # IDE
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+ .DS_Store
34
+
35
+ # Testing
36
+ .pytest_cache/
37
+ .hypothesis/
38
+ .coverage
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+
43
+ # Type checking
44
+ .mypy_cache/
45
+
46
+ # Local data
47
+ *.db
48
+ *.sqlite
49
+ data/
50
+
51
+ # MkDocs
52
+ site/
53
+
54
+ # Secrets
55
+ .env
56
+ .env.local
57
+ *.pem
58
+ *.key
@@ -0,0 +1,38 @@
1
+ Metadata-Version: 2.4
2
+ Name: qortex-learning
3
+ Version: 0.1.0
4
+ Summary: Bandit-based adaptive learning for qortex: Thompson Sampling, reward models, persistent state.
5
+ Author: Peleke Sengstacke
6
+ License-Expression: MIT
7
+ Keywords: adaptive-learning,bandit,knowledge-graph,thompson-sampling
8
+ Classifier: Development Status :: 2 - Pre-Alpha
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Requires-Python: >=3.11
14
+ Requires-Dist: aiosqlite>=0.20
15
+ Requires-Dist: qortex-observe
16
+ Provides-Extra: dev
17
+ Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
18
+ Requires-Dist: pytest>=8.0; extra == 'dev'
19
+ Description-Content-Type: text/markdown
20
+
21
+ # qortex-learning
22
+
23
+ Bandit-based adaptive learning for qortex. Thompson Sampling with Beta-Bernoulli posteriors, persistent state via SQLite, and pluggable reward models.
24
+
25
+ ## Usage
26
+
27
+ ```python
28
+ from qortex.learning import Learner, LearnerConfig, Arm, ArmOutcome
29
+
30
+ learner = await Learner.create(LearnerConfig(name="prompts"))
31
+
32
+ candidates = [Arm(id="v1", token_cost=10), Arm(id="v2", token_cost=15)]
33
+ result = await learner.select(candidates, context={"task": "type-check"}, k=1)
34
+
35
+ await learner.observe(ArmOutcome(arm_id="v2", outcome="accepted", reward=1.0))
36
+ ```
37
+
38
+ Part of the [qortex](https://github.com/Peleke/qortex) workspace.
@@ -0,0 +1,18 @@
1
+ # qortex-learning
2
+
3
+ Bandit-based adaptive learning for qortex. Thompson Sampling with Beta-Bernoulli posteriors, persistent state via SQLite, and pluggable reward models.
4
+
5
+ ## Usage
6
+
7
+ ```python
8
+ from qortex.learning import Learner, LearnerConfig, Arm, ArmOutcome
9
+
10
+ learner = await Learner.create(LearnerConfig(name="prompts"))
11
+
12
+ candidates = [Arm(id="v1", token_cost=10), Arm(id="v2", token_cost=15)]
13
+ result = await learner.select(candidates, context={"task": "type-check"}, k=1)
14
+
15
+ await learner.observe(ArmOutcome(arm_id="v2", outcome="accepted", reward=1.0))
16
+ ```
17
+
18
+ Part of the [qortex](https://github.com/Peleke/qortex) workspace.
@@ -0,0 +1,41 @@
1
+ [project]
2
+ name = "qortex-learning"
3
+ version = "0.1.0"
4
+ description = "Bandit-based adaptive learning for qortex: Thompson Sampling, reward models, persistent state."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ license = "MIT"
8
+ authors = [
9
+ { name = "Peleke Sengstacke" }
10
+ ]
11
+ keywords = [
12
+ "bandit",
13
+ "thompson-sampling",
14
+ "adaptive-learning",
15
+ "knowledge-graph",
16
+ ]
17
+ classifiers = [
18
+ "Development Status :: 2 - Pre-Alpha",
19
+ "Intended Audience :: Developers",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ ]
24
+
25
+ dependencies = [
26
+ "aiosqlite>=0.20",
27
+ "qortex-observe",
28
+ ]
29
+
30
+ [project.optional-dependencies]
31
+ dev = [
32
+ "pytest>=8.0",
33
+ "pytest-asyncio>=0.23",
34
+ ]
35
+
36
+ [build-system]
37
+ requires = ["hatchling"]
38
+ build-backend = "hatchling.build"
39
+
40
+ [tool.hatch.build.targets.wheel]
41
+ packages = ["src/qortex"]
@@ -0,0 +1,51 @@
1
+ """qortex learning module: bandit-based learning for prompt optimization.
2
+
3
+ Public API:
4
+ Learner — Main class: select(), observe(), metrics(), posteriors()
5
+ LearnerConfig — Configuration for a Learner instance
6
+ LearningStore — Protocol for state persistence backends
7
+ JsonLearningStore — JSON file backend
8
+ SqliteLearningStore — SQLite backend (default)
9
+ ThompsonSampling — Beta-Bernoulli Thompson Sampling strategy
10
+ BinaryReward — Binary reward model (accepted=1, else=0)
11
+ TernaryReward — Ternary reward model (accepted=1, partial=0.5, rejected=0)
12
+ """
13
+
14
+ from qortex.learning.learner import Learner
15
+ from qortex.learning.reward import BinaryReward, RewardModel, TernaryReward
16
+ from qortex.learning.store import JsonLearningStore, LearningStore, SqliteLearningStore
17
+
18
+ try:
19
+ from qortex.learning.pg_store import PostgresLearningStore
20
+ except ImportError:
21
+ PostgresLearningStore = None # type: ignore[assignment,misc]
22
+ from qortex.learning.strategy import LearningStrategy, ThompsonSampling
23
+ from qortex.learning.types import (
24
+ Arm,
25
+ ArmOutcome,
26
+ ArmState,
27
+ LearnerConfig,
28
+ RunTrace,
29
+ SelectionResult,
30
+ context_hash,
31
+ )
32
+
33
+ __all__ = [
34
+ "Learner",
35
+ "LearnerConfig",
36
+ "LearningStore",
37
+ "JsonLearningStore",
38
+ "SqliteLearningStore",
39
+ "PostgresLearningStore",
40
+ "LearningStrategy",
41
+ "ThompsonSampling",
42
+ "RewardModel",
43
+ "BinaryReward",
44
+ "TernaryReward",
45
+ "Arm",
46
+ "ArmOutcome",
47
+ "ArmState",
48
+ "SelectionResult",
49
+ "RunTrace",
50
+ "context_hash",
51
+ ]
@@ -0,0 +1,364 @@
1
+ """Learner: the main class for bandit-based learning.
2
+
3
+ Composes strategy + reward model + state store. Exposes select(),
4
+ observe(), batch_observe(), top_arms(), decay_arm(), metrics(),
5
+ and posteriors(). Emits observability events.
6
+
7
+ All I/O methods are async. Use ``await Learner.create(config)`` to
8
+ construct (seed boost requires store I/O).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import uuid
14
+ from datetime import UTC, datetime
15
+ from typing import Any
16
+
17
+ from qortex.learning.reward import RewardModel, TernaryReward
18
+ from qortex.learning.store import LearningStore, SqliteLearningStore
19
+ from qortex.learning.strategy import LearningStrategy, ThompsonSampling
20
+ from qortex.learning.types import (
21
+ Arm,
22
+ ArmOutcome,
23
+ ArmState,
24
+ LearnerConfig,
25
+ RunTrace,
26
+ SelectionResult,
27
+ context_hash,
28
+ )
29
+ from qortex.observe import emit
30
+ from qortex.observe.events import (
31
+ LearningObservationRecorded,
32
+ LearningPosteriorUpdated,
33
+ LearningSelectionMade,
34
+ )
35
+ from qortex.observe.tracing import traced
36
+
37
+
38
class Learner:
    """A bandit learner that selects arms and updates posteriors.

    Composes a selection strategy (default: Thompson Sampling), a reward
    model (default: ternary), and a persistence backend (default: SQLite).
    All I/O methods are async; construct with ``await Learner.create(config)``
    so seed boosts (which require store I/O) are applied.

    Usage::

        learner = await Learner.create(LearnerConfig(name="prompts"))
        result = await learner.select(candidates, context={"task": "type-errors"}, k=3)
        # ... use selected arms ...
        await learner.observe(ArmOutcome(arm_id="prompt:v2", reward=1.0, outcome="accepted"))
    """

    def __init__(
        self,
        config: LearnerConfig,
        strategy: LearningStrategy | None = None,
        reward_model: RewardModel | None = None,
        store: LearningStore | None = None,
    ) -> None:
        self.config = config
        self.strategy = strategy or ThompsonSampling()
        self.reward_model = reward_model or TernaryReward()
        self.store = store or SqliteLearningStore(config.name, config.state_dir)
        # Guards _apply_seed_boosts so repeated calls are no-ops.
        self._seeded = False
        # Open sessions keyed by session_id; see session_start/session_end.
        self._sessions: dict[str, RunTrace] = {}

    @classmethod
    async def create(
        cls,
        config: LearnerConfig,
        strategy: LearningStrategy | None = None,
        reward_model: RewardModel | None = None,
        store: LearningStore | None = None,
    ) -> Learner:
        """Create a Learner and apply seed boosts (requires async store I/O)."""
        learner = cls(config, strategy, reward_model, store)
        await learner._apply_seed_boosts()
        return learner

    async def _apply_seed_boosts(self) -> None:
        """Apply seed arm boosts. Idempotent — only boosts arms with zero pulls."""
        if self._seeded:
            return
        self._seeded = True
        for arm_id in self.config.seed_arms:
            state = await self.store.get(arm_id)
            if state.pulls == 0:
                # Fresh arm: raise alpha so seeded arms win early draws,
                # without recording any synthetic pulls or reward.
                state = ArmState(
                    alpha=self.config.seed_boost,
                    beta=1.0,
                    pulls=0,
                    total_reward=0.0,
                    last_updated=datetime.now(UTC).isoformat(),
                )
                await self.store.put(arm_id, state)
        if self.config.seed_arms:
            await self.store.save()

    @traced("learning.select")
    async def select(
        self,
        candidates: list[Arm],
        context: dict | None = None,
        k: int = 1,
        token_budget: int = 0,
    ) -> SelectionResult:
        """Select k arms from candidates.

        Args:
            candidates: Arms eligible for this selection round.
            context: Optional context dict keying per-context state.
            k: Number of arms to select.
            token_budget: Passed through to the strategy (0 = no budget).

        Emits a LearningSelectionMade event describing the result.
        """
        ctx = context or {}
        states = {arm.id: await self.store.get(arm.id, ctx) for arm in candidates}

        result = self.strategy.select(
            candidates=candidates,
            states=states,
            k=k,
            config=self.config,
            token_budget=token_budget,
        )

        emit(
            LearningSelectionMade(
                learner=self.config.name,
                selected_count=len(result.selected),
                excluded_count=len(result.excluded),
                is_baseline=result.is_baseline,
                token_budget=result.token_budget,
                used_tokens=result.used_tokens,
            )
        )

        return result

    @traced("learning.observe")
    async def observe(
        self,
        outcome: ArmOutcome,
        context: dict | None = None,
    ) -> ArmState:
        """Record an observation and update the arm's posterior.

        When the outcome carries a label and a falsy reward, the reward is
        derived from the reward model. Persists and returns the new state,
        and emits observation + posterior-update events.
        """
        ctx = context or outcome.context
        reward = outcome.reward
        # NOTE(review): `not outcome.reward` treats an explicit reward of 0.0
        # as "missing" and recomputes it from the label — confirm intended.
        if outcome.outcome and not outcome.reward:
            reward = self.reward_model.compute(outcome.outcome)

        state = await self.store.get(outcome.arm_id, ctx)
        now = datetime.now(UTC).isoformat()

        new_state = self.strategy.update(outcome.arm_id, reward, state)
        # Re-wrap the strategy output to stamp last_updated without mutating it.
        new_state = ArmState(
            alpha=new_state.alpha,
            beta=new_state.beta,
            pulls=new_state.pulls,
            total_reward=new_state.total_reward,
            last_updated=now,
        )

        await self.store.put(outcome.arm_id, new_state, ctx)
        await self.store.save()

        ctx_hash = context_hash(ctx)

        emit(
            LearningObservationRecorded(
                learner=self.config.name,
                arm_id=outcome.arm_id,
                reward=reward,
                outcome=outcome.outcome,
                context_hash=ctx_hash,
            )
        )

        emit(
            LearningPosteriorUpdated(
                learner=self.config.name,
                arm_id=outcome.arm_id,
                alpha=new_state.alpha,
                beta=new_state.beta,
                pulls=new_state.pulls,
                mean=new_state.mean,
            )
        )

        return new_state

    @traced("learning.apply_credit_deltas")
    async def apply_credit_deltas(
        self,
        deltas: dict[str, dict[str, float]],
        context: dict | None = None,
    ) -> dict[str, ArmState]:
        """Apply causal credit deltas directly to arm posteriors.

        Unlike observe(), which computes alpha/beta from a binary reward,
        this applies arbitrary ``alpha_delta``/``beta_delta`` values from
        CreditAssigner output. Alpha/beta are floored at 0.01 to avoid
        degenerate posteriors. Saves once after all deltas are applied.
        """
        ctx = context or {}
        now = datetime.now(UTC).isoformat()
        results: dict[str, ArmState] = {}

        for arm_id, delta in deltas.items():
            state = await self.store.get(arm_id, ctx)
            new_state = ArmState(
                alpha=max(state.alpha + delta.get("alpha_delta", 0.0), 0.01),
                beta=max(state.beta + delta.get("beta_delta", 0.0), 0.01),
                pulls=state.pulls + 1,
                # alpha_delta doubles as the reward credit for this pull.
                total_reward=state.total_reward + delta.get("alpha_delta", 0.0),
                last_updated=now,
            )
            await self.store.put(arm_id, new_state, ctx)
            results[arm_id] = new_state

            emit(
                LearningPosteriorUpdated(
                    learner=self.config.name,
                    arm_id=arm_id,
                    alpha=new_state.alpha,
                    beta=new_state.beta,
                    pulls=new_state.pulls,
                    mean=new_state.mean,
                )
            )

        await self.store.save()
        return results

    async def reset(
        self,
        arm_ids: list[str] | None = None,
        context: dict | None = None,
    ) -> int:
        """Delete arm states and return count of entries removed.

        Useful for resetting poisoned posteriors or clearing stale data.
        """
        count = await self.store.delete(arm_ids=arm_ids, context=context)
        await self.store.save()
        return count

    async def batch_observe(
        self,
        outcomes: list[ArmOutcome],
        context: dict | None = None,
    ) -> dict[str, ArmState]:
        """Record multiple observations in a single call.

        Delegates to observe() for each outcome so events are emitted
        per-arm; observe() persists after each update, so the store is
        saved once per outcome. When the same arm_id appears multiple
        times, the returned mapping holds its final state.
        """
        results: dict[str, ArmState] = {}
        for outcome in outcomes:
            results[outcome.arm_id] = await self.observe(outcome, context)
        return results

    async def top_arms(
        self,
        context: dict | None = None,
        k: int = 10,
    ) -> list[tuple[str, ArmState]]:
        """Return the top-k arms by posterior mean, descending.

        Returns (arm_id, ArmState) tuples so callers get both the ID
        and the full state.
        """
        all_states = await self.store.get_all(context)
        ranked = sorted(
            all_states.items(),
            key=lambda pair: pair[1].mean,
            reverse=True,
        )
        return ranked[:k]

    async def decay_arm(
        self,
        arm_id: str,
        decay_factor: float = 0.9,
        context: dict | None = None,
    ) -> ArmState:
        """Shrink an arm's learned signal toward the prior.

        Multiplies alpha and beta by decay_factor, weakening confidence
        while preserving the mean ratio. Useful when seed data changes
        and old signal should fade.

        Floors alpha/beta at 0.01 to avoid degenerate posteriors.
        """
        state = await self.store.get(arm_id, context)
        now = datetime.now(UTC).isoformat()
        new_state = ArmState(
            alpha=max(state.alpha * decay_factor, 0.01),
            beta=max(state.beta * decay_factor, 0.01),
            pulls=state.pulls,
            total_reward=state.total_reward * decay_factor,
            last_updated=now,
        )
        await self.store.put(arm_id, new_state, context)
        await self.store.save()
        return new_state

    async def posteriors(
        self,
        context: dict | None = None,
        arm_ids: list[str] | None = None,
    ) -> dict[str, dict[str, Any]]:
        """Get current posteriors (state dict plus derived mean) for arms.

        When arm_ids is given, the result is restricted to those IDs.
        """
        all_states = await self.store.get_all(context)

        if arm_ids is not None:
            wanted = set(arm_ids)  # build once; O(1) membership per arm
            all_states = {k: v for k, v in all_states.items() if k in wanted}

        return {
            arm_id: {
                **state.to_dict(),
                "mean": state.mean,
            }
            for arm_id, state in all_states.items()
        }

    async def metrics(self, window: int | None = None) -> dict[str, Any]:
        """Compute learning metrics across all contexts.

        Note: ``window`` is currently unused — all history is aggregated.
        It is kept in the signature for forward compatibility.
        """
        total_pulls = 0
        total_reward = 0.0
        arm_count = 0

        all_states = await self.store.get_all_states()
        for states in all_states.values():
            for state in states.values():
                total_pulls += state.pulls
                total_reward += state.total_reward
                arm_count += 1

        # max(..., 1) avoids division by zero when nothing has been pulled.
        accuracy = total_reward / max(total_pulls, 1)
        explore_ratio = self.config.baseline_rate

        return {
            "learner": self.config.name,
            "total_pulls": total_pulls,
            "total_reward": total_reward,
            "accuracy": round(accuracy, 4),
            "arm_count": arm_count,
            "explore_ratio": explore_ratio,
        }

    def session_start(self, session_name: str) -> str:
        """Start a named learning session for tracking; returns its ID."""
        session_id = str(uuid.uuid4())
        self._sessions[session_id] = RunTrace(
            session_id=session_id,
            learner=self.config.name,
            selected_arms=[],
            started_at=datetime.now(UTC).isoformat(),
        )
        return session_id

    def session_end(self, session_id: str) -> dict[str, Any]:
        """End a session and return its summary (or an error dict)."""
        trace = self._sessions.pop(session_id, None)
        if trace is None:
            return {"error": f"Session {session_id} not found"}

        trace.ended_at = datetime.now(UTC).isoformat()
        return {
            "session_id": session_id,
            "learner": trace.learner,
            "selected_arms": trace.selected_arms,
            "outcomes": trace.outcomes,
            "started_at": trace.started_at,
            "ended_at": trace.ended_at,
        }