rlhf-feedback-loop 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +26 -0
- package/LICENSE +21 -0
- package/README.md +308 -0
- package/adapters/README.md +8 -0
- package/adapters/amp/skills/rlhf-feedback/SKILL.md +20 -0
- package/adapters/chatgpt/INSTALL.md +80 -0
- package/adapters/chatgpt/openapi.yaml +292 -0
- package/adapters/claude/.mcp.json +8 -0
- package/adapters/codex/config.toml +4 -0
- package/adapters/gemini/function-declarations.json +95 -0
- package/adapters/mcp/server-stdio.js +444 -0
- package/bin/cli.js +167 -0
- package/config/mcp-allowlists.json +29 -0
- package/config/policy-bundles/constrained-v1.json +53 -0
- package/config/policy-bundles/default-v1.json +80 -0
- package/config/rubrics/default-v1.json +52 -0
- package/config/subagent-profiles.json +32 -0
- package/openapi/openapi.yaml +292 -0
- package/package.json +91 -0
- package/plugins/amp-skill/INSTALL.md +52 -0
- package/plugins/amp-skill/SKILL.md +31 -0
- package/plugins/claude-skill/INSTALL.md +55 -0
- package/plugins/claude-skill/SKILL.md +46 -0
- package/plugins/codex-profile/AGENTS.md +20 -0
- package/plugins/codex-profile/INSTALL.md +57 -0
- package/plugins/gemini-extension/INSTALL.md +74 -0
- package/plugins/gemini-extension/gemini_prompt.txt +10 -0
- package/plugins/gemini-extension/tool_contract.json +28 -0
- package/scripts/billing.js +471 -0
- package/scripts/budget-guard.js +173 -0
- package/scripts/code-reasoning.js +307 -0
- package/scripts/context-engine.js +547 -0
- package/scripts/contextfs.js +513 -0
- package/scripts/contract-audit.js +198 -0
- package/scripts/dpo-optimizer.js +208 -0
- package/scripts/export-dpo-pairs.js +316 -0
- package/scripts/export-training.js +448 -0
- package/scripts/feedback-attribution.js +313 -0
- package/scripts/feedback-inbox-read.js +162 -0
- package/scripts/feedback-loop.js +838 -0
- package/scripts/feedback-schema.js +300 -0
- package/scripts/feedback-to-memory.js +165 -0
- package/scripts/feedback-to-rules.js +109 -0
- package/scripts/generate-paperbanana-diagrams.sh +99 -0
- package/scripts/hybrid-feedback-context.js +676 -0
- package/scripts/intent-router.js +164 -0
- package/scripts/mcp-policy.js +92 -0
- package/scripts/meta-policy.js +194 -0
- package/scripts/plan-gate.js +154 -0
- package/scripts/prove-adapters.js +364 -0
- package/scripts/prove-attribution.js +364 -0
- package/scripts/prove-automation.js +393 -0
- package/scripts/prove-data-quality.js +219 -0
- package/scripts/prove-intelligence.js +256 -0
- package/scripts/prove-lancedb.js +370 -0
- package/scripts/prove-loop-closure.js +255 -0
- package/scripts/prove-rlaif.js +404 -0
- package/scripts/prove-subway-upgrades.js +250 -0
- package/scripts/prove-training-export.js +324 -0
- package/scripts/prove-v2-milestone.js +273 -0
- package/scripts/prove-v3-milestone.js +381 -0
- package/scripts/rlaif-self-audit.js +123 -0
- package/scripts/rubric-engine.js +230 -0
- package/scripts/self-heal.js +127 -0
- package/scripts/self-healing-check.js +111 -0
- package/scripts/skill-quality-tracker.js +284 -0
- package/scripts/subagent-profiles.js +79 -0
- package/scripts/sync-gh-secrets-from-env.sh +29 -0
- package/scripts/thompson-sampling.js +331 -0
- package/scripts/train_from_feedback.py +914 -0
- package/scripts/validate-feedback.js +580 -0
- package/scripts/vector-store.js +100 -0
- package/src/api/server.js +497 -0
|
@@ -0,0 +1,914 @@
|
|
|
1
|
+
#!/usr/bin/env python3
"""
Thompson Sampling Feedback Model Trainer

Beta-Bernoulli Thompson Sampling for per-category reliability estimation.
Reads from feedback-log.jsonl and builds a Bayesian model of Claude's
performance across different task categories.

Usage:
    python train_from_feedback.py --train         # Full rebuild from JSONL
    python train_from_feedback.py --incremental   # Update with latest entry
    python train_from_feedback.py --reliability   # Print reliability table
    python train_from_feedback.py --sample        # Sample from posteriors
    python train_from_feedback.py --snapshot      # Save model snapshot
    python train_from_feedback.py --dpo-train     # DPO batch optimization (Feb 2026)
    python train_from_feedback.py --config config.json  # Use custom categories

LOCAL ONLY - Do not commit to repository
"""

import sys
import json
import math
import random
import argparse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

# Configuration
# rlhf path: scripts/train_from_feedback.py -> scripts/ -> rlhf/
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = Path(__file__).parent.parent
# All mutable state lives under .claude/memory/feedback (local-only, per the
# "LOCAL ONLY" notice above): the raw JSONL log, the fitted model, snapshots.
FEEDBACK_LOG = PROJECT_ROOT / ".claude" / "memory" / "feedback" / "feedback-log.jsonl"
MODEL_FILE = PROJECT_ROOT / ".claude" / "memory" / "feedback" / "feedback_model.json"
SNAPSHOTS_DIR = PROJECT_ROOT / ".claude" / "memory" / "feedback" / "model_snapshots"

# Default categories (overridden by --config).
# Each category matches an entry when any keyword appears in the entry's free
# text, or any listed tool name appears in the entry's last_tool field
# (see classify_entry).
DEFAULT_CATEGORIES = {
    "code_edit": {
        "keywords": ["edit", "write", "implement", "refactor", "fix", "update", "create file"],
        "tools": ["Edit", "Write", "MultiEdit"],
    },
    "git": {
        "keywords": ["commit", "push", "branch", "merge", "pr", "pull request", "rebase", "cherry-pick"],
        "tools": ["Bash"],
    },
    "testing": {
        "keywords": ["test", "jest", "coverage", "reassure", "perf", "spec", "mock", "assert"],
        "tools": [],
    },
    "pr_review": {
        "keywords": ["review", "pr comment", "resolve", "minimize", "thread", "feedback"],
        "tools": [],
    },
    "search": {
        "keywords": ["search", "find", "grep", "glob", "explore", "where is", "look for"],
        "tools": ["Grep", "Glob", "Read"],
    },
    "architecture": {
        "keywords": ["architecture", "design", "pattern", "structure", "fsd", "module", "navigation"],
        "tools": [],
    },
    "security": {
        "keywords": ["security", "secret", "vulnerability", "injection", "xss", "owasp", "trufflehog"],
        "tools": [],
    },
    "debugging": {
        "keywords": ["debug", "error", "crash", "stack trace", "log", "diagnose", "investigate"],
        "tools": [],
    },
}

# Time decay configuration (2026 upgrade: exponential decay with half-life).
# Step decay (legacy): key is the upper age bound in days, None is the
# catch-all for stale entries. Used only when USE_EXPONENTIAL_DECAY is False.
DECAY_WEIGHTS = {
    7: 1.0,     # < 7 days: full weight
    30: 0.5,    # 7-30 days: half weight
    None: 0.25  # > 30 days: quarter weight
}

# Exponential decay (2026 best practice)
# Half-life of 7 days: feedback loses half its weight every 7 days
HALF_LIFE_DAYS = 7.0
USE_EXPONENTIAL_DECAY = True  # Toggle between step and exponential
|
|
88
|
+
def load_config(config_path: Optional[str]) -> Dict:
    """Return the category configuration.

    If *config_path* names an existing JSON file, its parsed content is
    returned; otherwise (no path given, or path missing) the built-in
    DEFAULT_CATEGORIES mapping is used.
    """
    if not config_path:
        return DEFAULT_CATEGORIES
    candidate = Path(config_path)
    if not candidate.exists():
        # Silently fall back — mirrors the CLI's "defaults unless told otherwise".
        return DEFAULT_CATEGORIES
    return json.loads(candidate.read_text())
|
+
def load_model() -> Dict:
    """Load the persisted model, or create one with uniform priors.

    Returns a fresh Beta(1,1) model when MODEL_FILE is missing, contains
    invalid JSON, or cannot be read (the original let an OSError from an
    unreadable file propagate; now it degrades to a rebuild like the
    corrupt-JSON case).
    """
    if MODEL_FILE.exists():
        try:
            return json.loads(MODEL_FILE.read_text())
        except (OSError, json.JSONDecodeError):
            # Corrupt or unreadable model file: fall through and rebuild.
            pass
    return create_initial_model(DEFAULT_CATEGORIES)
|
|
107
|
+
def create_initial_model(categories: Dict) -> Dict:
    """Create model with uniform Beta(1,1) priors for all categories.

    Each category starts at alpha=1, beta=1 (the uninformative prior:
    posterior mean 0.5) with no samples seen yet.
    """
    priors = {
        name: {
            "alpha": 1.0,   # Prior successes + 1
            "beta": 1.0,    # Prior failures + 1
            "samples": 0,
            "last_updated": None,
        }
        for name in categories
    }
    return {
        "version": 1,
        "created": datetime.now().isoformat(),
        "updated": datetime.now().isoformat(),
        "total_entries": 0,
        "categories": priors,
    }
|
+
|
|
126
|
+
def save_model(model: Dict):
    """Persist the model to MODEL_FILE, stamping the update time."""
    model["updated"] = datetime.now().isoformat()
    MODEL_FILE.parent.mkdir(parents=True, exist_ok=True)
    MODEL_FILE.write_text(json.dumps(model, indent=2))
|
|
133
|
+
def time_decay_weight(timestamp_str: str) -> float:
    """Compute the time-decay weight for a feedback entry.

    2026 Upgrade: Supports both step decay and exponential decay.
    Exponential decay uses the half-life formula: weight = 2^(-age/half_life)

    Args:
        timestamp_str: ISO-8601 timestamp; a trailing "Z" or "+HH:MM"
            offset is stripped before parsing.

    Returns:
        Weight in (0, 1]. Unparseable timestamps get the stale-entry
        weight DECAY_WEIGHTS[None].
    """
    try:
        # Strip "Z" / UTC offset so datetime.fromisoformat accepts the string.
        ts_clean = timestamp_str.replace("Z", "").split("+")[0]
        entry_time = datetime.fromisoformat(ts_clean)
    except (ValueError, AttributeError):
        return DECAY_WEIGHTS[None]

    # Clamp at 0: clock skew or future timestamps previously gave a negative
    # age and therefore a weight > 1.0, over-counting those entries.
    age_days = max((datetime.now() - entry_time).days, 0)

    if USE_EXPONENTIAL_DECAY:
        # weight = 2^(-age/half_life): 1.0 at age 0, 0.5 at one half-life, ...
        weight = 2 ** (-age_days / HALF_LIFE_DAYS)
        return max(weight, 0.01)  # Floor at 1% to prevent zero weights
    else:
        # Legacy step decay: first threshold bucket the age falls under wins;
        # the None key sorts last and acts as the catch-all.
        for threshold, weight in sorted(DECAY_WEIGHTS.items(), key=lambda x: (x[0] is None, x[0])):
            if threshold is not None and age_days < threshold:
                return weight
        return DECAY_WEIGHTS[None]
|
159
|
+
|
|
160
|
+
def classify_entry(entry: Dict, categories: Dict) -> List[str]:
    """Classify a feedback entry into categories based on keywords/tools.

    A category matches when any of its keywords occurs in the entry's
    free-text fields (context, message, last_action, tags) or any of its
    tool names occurs in last_tool. Entries matching nothing fall into
    the "uncategorized" bucket.
    """

    def _lower(field: str) -> str:
        # Treat missing fields and explicit nulls the same way.
        return (entry.get(field, "") or "").lower()

    tags = entry.get("tags", [])
    tags_str = " ".join(t.lower() for t in tags) if isinstance(tags, list) else ""
    last_tool = _lower("last_tool")
    # Keyword search space; the tool name is matched separately below.
    searchable = " ".join([_lower("context"), _lower("message"), _lower("last_action"), tags_str])

    matched = [
        name
        for name, cfg in categories.items()
        if any(kw.lower() in searchable for kw in cfg.get("keywords", []))
        or any(t.lower() in last_tool for t in cfg.get("tools", []))
    ]
    return matched or ["uncategorized"]
192
|
+
|
|
193
|
+
def load_feedback_entries() -> List[Dict]:
    """Load all feedback entries from the JSONL log.

    Blank lines and lines that fail to parse as JSON are skipped silently;
    a missing log yields an empty list.
    """
    if not FEEDBACK_LOG.exists():
        return []

    entries: List[Dict] = []
    with open(FEEDBACK_LOG) as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                continue
            entries.append(parsed)
    return entries
210
|
+
|
|
211
|
+
def is_positive(entry: Dict) -> bool:
    """Determine if a feedback entry is positive.

    Checks, in order: a positive numeric `reward`, the rlhf `signal`
    field ('positive'/'negative'), then the legacy `feedback` field.
    Missing or explicitly-null fields are treated as non-positive
    (previously a null signal/feedback raised AttributeError and a null
    reward raised TypeError).
    """
    # `or 0` guards against {"reward": null} in the JSONL.
    if (entry.get("reward") or 0) > 0:
        return True
    # rlhf uses signal field: 'positive' or 'negative'
    signal = str(entry.get("signal") or "").lower()
    if signal in ("positive", "up", "thumbsup"):
        return True
    feedback = str(entry.get("feedback") or "").lower()
    return feedback in ("positive", "up", "thumbsup")
222
|
+
|
|
223
|
+
def train_full(categories: Dict) -> Dict:
    """Full rebuild: read all entries, compute posteriors.

    Starts from uniform Beta(1,1) priors, then folds in every logged
    entry: a positive entry adds its time-decay weight to alpha, a
    negative one to beta, for every category the entry matches.
    """
    entries = load_feedback_entries()
    model = create_initial_model(categories)
    model["total_entries"] = len(entries)

    table = model["categories"]

    def _fresh() -> Dict:
        return {"alpha": 1.0, "beta": 1.0, "samples": 0, "last_updated": None}

    # Catch-all bucket for entries that match no configured category.
    table.setdefault("uncategorized", _fresh())

    for entry in entries:
        weight = time_decay_weight(entry.get("timestamp", ""))
        positive = is_positive(entry)

        for cat in classify_entry(entry, categories):
            stats = table.setdefault(cat, _fresh())
            stats["alpha" if positive else "beta"] += weight
            stats["samples"] += 1
            stats["last_updated"] = entry.get("timestamp")

    save_model(model)
    return model
|
|
258
|
+
def train_incremental(categories: Dict) -> Dict:
    """Incremental update: process only the latest entry.

    Loads the persisted model, ensures every configured category (plus
    the "uncategorized" fallback) has a slot, then folds the newest log
    entry into the matching categories' Beta parameters.
    """
    entries = load_feedback_entries()
    if not entries:
        return load_model()

    model = load_model()
    table = model["categories"]

    def _fresh() -> Dict:
        return {"alpha": 1.0, "beta": 1.0, "samples": 0, "last_updated": None}

    for name in list(categories) + ["uncategorized"]:
        table.setdefault(name, _fresh())

    latest = entries[-1]
    weight = time_decay_weight(latest.get("timestamp", ""))
    positive = is_positive(latest)

    for cat in classify_entry(latest, categories):
        stats = table.setdefault(cat, _fresh())
        stats["alpha" if positive else "beta"] += weight
        stats["samples"] += 1
        stats["last_updated"] = latest.get("timestamp")

    model["total_entries"] = len(entries)
    save_model(model)
    return model
|
300
|
+
|
|
301
|
+
def compute_reliability(model: Dict) -> List[Tuple[str, float, float, float, int, float]]:
    """Compute reliability (posterior mean) for each category.

    Returns:
        Tuples of (category, alpha, beta, reliability, samples, ci_width),
        sorted by reliability descending. ci_width is the approximate 95%
        credible-interval width from the Beta variance. (Annotation fixed:
        the original declared a 5-tuple but always appended 6 elements.)
    """
    results = []
    for cat_name, params in model.get("categories", {}).items():
        alpha = params["alpha"]
        beta_val = params["beta"]
        samples = params["samples"]

        total = alpha + beta_val
        # Posterior mean of Beta(a, b): a / (a + b).
        reliability = alpha / total if total > 0 else 0.5

        # 95% credible interval width (normal approximation):
        # Var[Beta(a, b)] = ab / ((a+b)^2 * (a+b+1)).
        # (For non-negative a, b with total > 0, total + 1 > 0 always holds,
        # so the original's extra check was redundant.)
        if total > 0:
            variance = (alpha * beta_val) / (total * total * (total + 1))
            ci_width = 2 * 1.96 * math.sqrt(variance)
        else:
            ci_width = 1.0

        results.append((cat_name, alpha, beta_val, reliability, samples, ci_width))

    return sorted(results, key=lambda x: -x[3])
+
|
|
325
|
+
|
|
326
|
+
def sample_posteriors(model: Dict) -> Dict[str, float]:
|
|
327
|
+
"""Thompson Sampling: draw from each category's posterior."""
|
|
328
|
+
samples = {}
|
|
329
|
+
for cat_name, params in model.get("categories", {}).items():
|
|
330
|
+
alpha = max(params["alpha"], 0.01)
|
|
331
|
+
beta_val = max(params["beta"], 0.01)
|
|
332
|
+
samples[cat_name] = random.betavariate(alpha, beta_val)
|
|
333
|
+
return samples
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def save_snapshot(model: Dict) -> Path:
    """Save a timestamped model snapshot for lift comparison.

    Returns the path of the snapshot file that was written.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = SNAPSHOTS_DIR / f"model_{stamp}.json"
    SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(model, indent=2))
    return target
+
|
|
344
|
+
|
|
345
|
+
# ============================================
# META-POLICY RULES (2026 Best Practice)
# Consolidate repeated mistakes into reusable rules
# Based on: Meta-Policy Reflexion (arXiv:2509.03990)
# ============================================

# Rules extracted from the feedback log are persisted here (see
# extract_meta_policy_rules / save_meta_policy_rules / load_meta_policy_rules).
META_POLICY_FILE = PROJECT_ROOT / ".claude" / "memory" / "feedback" / "meta_policy_rules.json"
|
353
|
+
|
|
354
|
+
def extract_meta_policy_rules(min_occurrences: int = 3) -> List[Dict[str, Any]]:
    """Extract reusable rules from repeated negative feedback patterns.

    Feb 2026 Upgrade: Recency + intensity weighted confidence.
    - Recent mistakes weigh more than old ones (exponential decay)
    - High-intensity feedback (user frustration) boosts confidence faster
    - Rules include trend analysis (improving vs deteriorating)

    Args:
        min_occurrences: Minimum times a pattern must appear to become a rule

    Returns:
        List of meta-policy rules with condition, action, weighted confidence,
        sorted by confidence descending.
    """
    entries = load_feedback_entries()
    negative_entries = [e for e in entries if not is_positive(e)]

    if len(negative_entries) < min_occurrences:
        return []

    def _is_recent(entry: Dict) -> bool:
        """True when the entry's timestamp parses and is within the last 7 days."""
        try:
            ts = entry.get("timestamp", "").replace("Z", "").split("+")[0]
            return (datetime.now() - datetime.fromisoformat(ts)).days <= 7
        except (ValueError, AttributeError):
            return False

    # Group negatives by category.
    category_patterns: Dict[str, List[Dict]] = {}
    for entry in negative_entries:
        for cat in classify_entry(entry, DEFAULT_CATEGORIES):
            category_patterns.setdefault(cat, []).append(entry)

    # Count positives per category ONCE (total and last-7-days), instead of
    # re-classifying every positive entry inside the per-category loop below
    # (the original was O(categories x positives)).
    positive_entries = [e for e in entries if is_positive(e)]
    category_positives: Dict[str, int] = {}
    category_recent_positives: Dict[str, int] = {}
    for entry in positive_entries:
        recent = _is_recent(entry)
        for cat in classify_entry(entry, DEFAULT_CATEGORIES):
            category_positives[cat] = category_positives.get(cat, 0) + 1
            if recent:
                category_recent_positives[cat] = category_recent_positives.get(cat, 0) + 1

    rules = []
    for category, patterns in category_patterns.items():
        if len(patterns) < min_occurrences:
            continue

        # Feb 2026: Recency + intensity weighted confidence.
        weighted_sum = 0.0
        recent_count = 0  # Negatives in the last 7 days
        for e in patterns:
            recency = time_decay_weight(e.get("timestamp", ""))
            # `or 3` guards against an explicit null intensity (the original
            # raised TypeError on None); scale 1-5 -> 0-1.
            intensity = (e.get("intensity") or 3) / 5.0
            weighted_sum += recency * (0.5 + 0.5 * intensity)  # Blend recency + intensity
            if _is_recent(e):
                recent_count += 1
        total_weight = float(len(patterns))

        recent_positive = category_recent_positives.get(category, 0)

        # Weighted confidence: base + recency-weighted adjustment, capped at 0.95.
        avg_weighted = weighted_sum / total_weight if total_weight > 0 else 0
        confidence = min(0.95, 0.4 + (avg_weighted * 0.3) + (len(patterns) * 0.05))

        # Trend: improving or deteriorating.
        total_positives = category_positives.get(category, 0)
        denom = total_positives + len(patterns)
        pos_ratio = total_positives / denom if denom > 0 else 0
        if recent_count == 0 and recent_positive > 0:
            trend = "improving"
        elif recent_count > 2 and recent_positive == 0:
            trend = "deteriorating"
        elif recent_count > recent_positive:
            trend = "needs_attention"
        else:
            trend = "stable"

        rule = {
            "id": f"rule_{category}_{len(patterns)}",
            "category": category,
            "occurrences": len(patterns),
            "confidence": round(confidence, 3),
            "weighted_confidence": round(avg_weighted, 4),
            "trend": trend,
            "recent_negatives_7d": recent_count,
            "recent_positives_7d": recent_positive,
            "positive_ratio": round(pos_ratio, 3),
            "created": datetime.now().isoformat(),
            "condition": f"When working on {category} tasks",
            "action": f"Pay extra attention - {len(patterns)} past mistakes in this area",
            # Three most recent failing contexts, truncated to 100 chars.
            "examples": [
                e.get("context", e.get("message", ""))[:100]
                for e in sorted(patterns, key=lambda x: x.get("timestamp", ""), reverse=True)[:3]
            ],
        }

        # Category-specific overrides for the generic action text.
        specific_actions = {
            "git": "VERIFY git operations before executing - check branch, status, diff",
            "code_edit": "READ the file first, understand context before editing",
            "testing": "Run tests after changes, don't assume they pass",
            "pr_review": "Address ALL review comments, don't just minimize",
            "debugging": "Verify the fix actually works - don't claim success without evidence",
        }
        if category in specific_actions:
            rule["action"] = specific_actions[category]

        rules.append(rule)

    # Sort by confidence descending (most urgent first).
    rules.sort(key=lambda r: r["confidence"], reverse=True)
    return rules
|
481
|
+
|
|
482
|
+
def save_meta_policy_rules(rules: List[Dict[str, Any]]):
    """Write the extracted rules (with update metadata) to META_POLICY_FILE."""
    payload = {
        "updated": datetime.now().isoformat(),
        "rule_count": len(rules),
        "rules": rules,
    }
    META_POLICY_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(META_POLICY_FILE, "w") as fh:
        json.dump(payload, fh, indent=2)
492
|
+
|
|
493
|
+
def load_meta_policy_rules() -> List[Dict[str, Any]]:
    """Read rules back from META_POLICY_FILE; empty list if absent or corrupt."""
    if not META_POLICY_FILE.exists():
        return []
    try:
        with open(META_POLICY_FILE) as fh:
            data = json.load(fh)
    except (json.JSONDecodeError, KeyError):
        return []
    return data.get("rules", [])
|
|
504
|
+
|
|
505
|
+
# ============================================
# DPO-STYLE BATCH OPTIMIZATION (Feb 2026)
# Direct Preference Optimization without explicit reward model.
# Builds preference pairs from positive/negative feedback,
# then adjusts category priors more aggressively than
# simple counting — mimicking DPO's closed-form update.
#
# Reference: Rafailov et al. 2023 (arXiv:2305.18290)
# ============================================

# Where train_dpo persists its per-category adjustment metadata.
DPO_MODEL_FILE = PROJECT_ROOT / ".claude" / "memory" / "feedback" / "dpo_model.json"
DPO_BETA = 0.1  # Temperature parameter (lower = more aggressive preference following)
|
518
|
+
|
|
519
|
+
def _override_dpo_beta(value: float):
    """Override DPO_BETA at module level.

    NOTE(review): a function that captures DPO_BETA as a *default
    argument* binds the value at definition time and will not observe
    this override unless it re-reads the global at call time — confirm
    callers of dpo_log_ratio account for that.
    """
    global DPO_BETA
    DPO_BETA = value
|
524
|
+
|
|
525
|
+
def build_preference_pairs(categories: Dict) -> Dict[str, List[Tuple[Dict, Dict]]]:
    """Build (chosen, rejected) preference pairs per category.

    For each category, pair the most recent positive entry with the most
    recent negative entry. This creates implicit preference data without
    needing explicit A/B comparisons.

    Returns:
        Mapping of category name -> list of (positive_entry, negative_entry)
        pairs. Categories lacking either polarity produce no pairs.
    """
    entries = load_feedback_entries()
    if not entries:
        return {}

    # Classify entries by category and sentiment.
    cat_positives: Dict[str, List[Dict]] = {}
    cat_negatives: Dict[str, List[Dict]] = {}

    for entry in entries:
        cats = classify_entry(entry, categories)
        for cat in cats:
            if is_positive(entry):
                cat_positives.setdefault(cat, []).append(entry)
            else:
                cat_negatives.setdefault(cat, []).append(entry)

    # Build pairs: each positive paired with closest-in-time negative.
    pairs: Dict[str, List[Tuple[Dict, Dict]]] = {}
    all_cats = set(list(cat_positives.keys()) + list(cat_negatives.keys()))

    for cat in all_cats:
        pos = cat_positives.get(cat, [])
        neg = cat_negatives.get(cat, [])
        # A pair needs one entry of each polarity.
        if not pos or not neg:
            continue

        cat_pairs = []
        # Sort chronologically by ISO timestamp (lexicographic == temporal).
        pos_sorted = sorted(pos, key=lambda e: e.get("timestamp", ""))
        neg_sorted = sorted(neg, key=lambda e: e.get("timestamp", ""))

        # Pair each positive with the nearest unused negative (greedy matching;
        # O(P*N) per category, each negative used at most once).
        used_neg = set()
        for p in pos_sorted:
            best_neg = None
            best_dist = float("inf")
            for i, n in enumerate(neg_sorted):
                if i in used_neg:
                    continue
                try:
                    p_ts = datetime.fromisoformat(p.get("timestamp", "").replace("Z", "").split("+")[0])
                    n_ts = datetime.fromisoformat(n.get("timestamp", "").replace("Z", "").split("+")[0])
                    dist = abs((p_ts - n_ts).total_seconds())
                except (ValueError, AttributeError):
                    # Unparseable timestamps make the candidate least preferred.
                    dist = float("inf")
                if dist < best_dist:
                    best_dist = dist
                    best_neg = i
            # If every remaining candidate had dist == inf, best_neg stays None
            # and this positive goes unpaired.
            if best_neg is not None:
                used_neg.add(best_neg)
                cat_pairs.append((p, neg_sorted[best_neg]))

        if cat_pairs:
            pairs[cat] = cat_pairs

    return pairs
|
|
589
|
+
|
|
590
|
+
def dpo_log_ratio(chosen_weight: float, rejected_weight: float, beta: Optional[float] = None) -> float:
    """Compute DPO implicit reward difference.

    DPO loss: -log(sigmoid(beta * (log pi(chosen) - log pi(rejected))))
    We use time-decay weights as proxy for log-probabilities.

    Args:
        chosen_weight: Time-decay weight of the preferred entry.
        rejected_weight: Time-decay weight of the rejected entry.
        beta: Temperature; defaults to the current module-level DPO_BETA.
            (Fixed: the old `beta=DPO_BETA` default was evaluated at
            definition time, so `_override_dpo_beta` never affected
            default-argument calls.)

    Returns:
        Adjustment in [-1, 1] to apply to category alpha/beta parameters.
    """
    if beta is None:
        beta = DPO_BETA  # Read at call time so _override_dpo_beta takes effect.

    # Avoid log(0)
    chosen_weight = max(chosen_weight, 0.01)
    rejected_weight = max(rejected_weight, 0.01)

    log_ratio = math.log(chosen_weight) - math.log(rejected_weight)
    sigmoid = 1.0 / (1.0 + math.exp(-beta * log_ratio))

    # Scale adjustment: larger preference gap -> larger update (range -1..1).
    adjustment = (sigmoid - 0.5) * 2
    return adjustment
609
|
+
|
|
610
|
+
def train_dpo(categories: Dict) -> Dict:
    """DPO-style batch optimization (Feb 2026 upgrade).

    Instead of simple counting, uses preference pairs to compute
    direct policy updates. Works alongside Thompson Sampling:
    - Thompson Sampling: online exploration (per-feedback updates)
    - DPO: batch exploitation (accumulated preference pairs)

    The DPO adjustment is applied on top of the Thompson model, and the
    per-category adjustment metadata is written to DPO_MODEL_FILE.
    """
    pairs = build_preference_pairs(categories)
    if not pairs:
        print("No preference pairs found. Need both positive and negative feedback per category.")
        return load_model()

    model = load_model()

    dpo_adjustments = {}

    for cat, cat_pairs in pairs.items():
        # NOTE(review): categories absent from the Thompson model are skipped
        # rather than created with fresh priors — confirm this is intended.
        if cat not in model["categories"]:
            continue

        total_adjustment = 0.0
        for chosen, rejected in cat_pairs:
            chosen_weight = time_decay_weight(chosen.get("timestamp", ""))
            rejected_weight = time_decay_weight(rejected.get("timestamp", ""))

            # Compute DPO-style adjustment for this (chosen, rejected) pair.
            adj = dpo_log_ratio(chosen_weight, rejected_weight)
            total_adjustment += adj

        # Average adjustment over all pairs
        avg_adjustment = total_adjustment / len(cat_pairs) if cat_pairs else 0

        # Apply DPO adjustment to model parameters
        # Positive adjustment → boost alpha (more reliable)
        # Negative adjustment → boost beta (less reliable)
        if avg_adjustment > 0:
            boost = avg_adjustment * len(cat_pairs) * 0.5  # Scale by pair count
            model["categories"][cat]["alpha"] += boost
        else:
            penalty = abs(avg_adjustment) * len(cat_pairs) * 0.5
            model["categories"][cat]["beta"] += penalty

        dpo_adjustments[cat] = {
            "pairs": len(cat_pairs),
            "avg_adjustment": round(avg_adjustment, 4),
            "direction": "boost" if avg_adjustment > 0 else "penalize",
        }

    # Save DPO metadata alongside (not inside) the Thompson model.
    dpo_meta = {
        "updated": datetime.now().isoformat(),
        "beta": DPO_BETA,
        "total_pairs": sum(len(p) for p in pairs.values()),
        "categories": dpo_adjustments,
    }
    DPO_MODEL_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(DPO_MODEL_FILE, "w") as f:
        json.dump(dpo_meta, f, indent=2)

    save_model(model)
    return model
|
|
675
|
+
|
|
676
|
+
def print_dpo_results(model: Dict):
    """Print a summary of the latest DPO batch-optimization run.

    Reads the metadata written by the DPO trainer from DPO_MODEL_FILE.
    NOTE(review): the `model` argument is not consulted here — presumably
    kept for signature symmetry with the other report printers; confirm.
    """
    if not DPO_MODEL_FILE.exists():
        print("\nNo DPO model found. Run --dpo-train first.")
        return

    with open(DPO_MODEL_FILE) as fh:
        dpo_meta = json.load(fh)

    divider = "=" * 60
    print()
    print(divider)
    print("DPO BATCH OPTIMIZATION RESULTS (Feb 2026)")
    print(divider)
    print(f" Beta (temperature): {dpo_meta.get('beta', DPO_BETA)}")
    print(f" Total preference pairs: {dpo_meta.get('total_pairs', 0)}")
    print(f" Updated: {dpo_meta.get('updated', 'never')}")
    print()

    # Largest absolute adjustments are listed first.
    ranked = sorted(
        dpo_meta.get("categories", {}).items(),
        key=lambda item: abs(item[1].get("avg_adjustment", 0)),
        reverse=True,
    )
    for cat, adj in ranked:
        arrow = "+" if adj.get("direction", "none") == "boost" else "-"
        # Ten-slot gauge: one '#' per 0.1 of |avg_adjustment|, capped at 10.
        filled = min(10, int(abs(adj.get("avg_adjustment", 0)) * 10))
        bar = "#" * filled + "-" * (10 - filled)
        print(f" {cat:<20s} [{bar}] {arrow}{abs(adj.get('avg_adjustment', 0)):.4f} ({adj.get('pairs', 0)} pairs)")

    print()
    print(" DPO adjusts Thompson Sampling priors based on preference pairs.")
    print(" Run --reliability to see combined effect.")
    print(divider)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def print_meta_policy_rules():
    """Render the extracted meta-policy rules as a readable report."""
    rules = load_meta_policy_rules()
    divider = "=" * 60

    print()
    print(divider)
    print("META-POLICY RULES (Recency + Intensity Weighted)")
    print(divider)

    if not rules:
        print("\n No rules extracted yet. Need more feedback data.")
        print(" Run --extract-rules after accumulating feedback.")
    else:
        # One glyph per known trend; anything unrecognized renders as '?'.
        trend_icon = {"improving": "+", "deteriorating": "!", "needs_attention": "?", "stable": "="}
        for rule in rules:
            # Confidence gauge: one '#' per 0.1 of confidence.
            conf_bar = "#" * int(rule["confidence"] * 10)
            trend = rule.get("trend", "unknown")
            trend_char = trend_icon.get(trend, "?")
            print(f"\n [{rule['category'].upper()}] Confidence: [{conf_bar}] {rule['confidence']:.0%} (trend: {trend_char} {trend})")
            print(f" Condition: {rule['condition']}")
            print(f" Action: {rule['action']}")
            print(f" Based on: {rule['occurrences']} negatives | Positive ratio: {rule.get('positive_ratio', 0):.0%}")
            recent_neg = rule.get("recent_negatives_7d", 0)
            recent_pos = rule.get("recent_positives_7d", 0)
            # Only show the 7-day line when there is recent activity at all.
            if recent_neg or recent_pos:
                print(f" Last 7d: {recent_neg} neg / {recent_pos} pos")

    print("\n" + divider)
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def print_reliability_table(model: Dict):
    """Print the Thompson Sampling reliability table for *model*.

    Rows come from compute_reliability(model); the Best/Worst summary below
    presumably relies on rows arriving sorted best-first — confirm against
    compute_reliability.
    """
    rows = compute_reliability(model)
    rule_line = "=" * 78

    print()
    print(rule_line)
    print("THOMPSON SAMPLING RELIABILITY TABLE")
    print(rule_line)
    print()
    print(f" Model updated: {model.get('updated', 'never')}")
    print(f" Total entries: {model.get('total_entries', 0)}")
    print()
    print(f" {'Category':<20s} | {'Alpha':>7s} | {'Beta':>7s} | {'Reliability':>12s} | {'Samples':>7s} | {'CI Width':>8s}")
    print(" " + "-" * 74)

    for cat, alpha, beta_val, reliability, samples, ci_width in rows:
        # Ten-slot gauge: '#' for filled tenths of reliability, '-' otherwise.
        filled = int(reliability * 10)
        bar = "#" * filled + "-" * (10 - filled)
        print(f" {cat:<20s} | {alpha:>7.1f} | {beta_val:>7.1f} | [{bar}] {reliability:>4.0%} | {samples:>7d} | {ci_width:>7.3f}")

    print()
    print(rule_line)

    if rows:
        best = rows[0]
        worst = rows[-1]
        print(f" Best: {best[0]} ({best[3]:.0%})")
        print(f" Worst: {worst[0]} ({worst[3]:.0%})")
        print()

    # Flag categories with reliability below 50% that have at least 3 samples.
    weak = [r for r in rows if r[3] < 0.5 and r[4] >= 3]
    if weak:
        print(" Categories needing improvement:")
        for cat, _, _, rel, samp, _ in weak:
            print(f" - {cat}: {rel:.0%} ({samp} samples)")
        print()

    print(rule_line)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def print_samples(model: Dict):
    """Print one Thompson draw per category, highest probability first."""
    samples = sample_posteriors(model)
    divider = "=" * 50

    print()
    print(divider)
    print("THOMPSON SAMPLING (Single Draw)")
    print(divider)
    print()

    for cat, prob in sorted(samples.items(), key=lambda kv: kv[1], reverse=True):
        # Twenty-slot gauge: one '#' per 0.05 of probability.
        filled = int(prob * 20)
        bar = "#" * filled + "-" * (20 - filled)
        print(f" {cat:<20s} [{bar}] {prob:.3f}")

    print()
    print(" (Each run produces different samples - this is expected)")
    print(divider)
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def main():
    """CLI entry point: parse flags and run exactly one trainer/report action.

    Exactly one action flag is honored per invocation (first match wins in
    the if/elif chain below); with no action flag, help is printed.
    --json switches every action's output to machine-readable JSON for
    hook consumption.
    """
    parser = argparse.ArgumentParser(description="Thompson Sampling Feedback Model Trainer (2026)")
    parser.add_argument("--train", action="store_true", help="Full rebuild from JSONL")
    parser.add_argument("--incremental", action="store_true", help="Update with latest entry")
    parser.add_argument("--reliability", action="store_true", help="Print reliability table")
    parser.add_argument("--sample", action="store_true", help="Sample from posteriors")
    parser.add_argument("--snapshot", action="store_true", help="Save model snapshot")
    parser.add_argument("--extract-rules", action="store_true", help="Extract meta-policy rules (2026)")
    parser.add_argument("--show-rules", action="store_true", help="Show meta-policy rules")
    parser.add_argument("--dpo-train", action="store_true", help="DPO batch optimization (Feb 2026)")
    parser.add_argument("--dpo-beta", type=float, default=DPO_BETA, help="DPO temperature parameter")
    parser.add_argument("--config", type=str, help="Path to custom categories JSON")
    parser.add_argument("--json", action="store_true", help="Output as JSON (for hook consumption)")

    args = parser.parse_args()

    categories = load_config(args.config)

    if args.train:
        model = train_full(categories)
        # Auto-run DPO batch optimization on full train (Feb 2026: autonomous)
        dpo_model = train_dpo(categories)
        # Auto-extract meta-policy rules with recency+intensity weighting
        rules = extract_meta_policy_rules()
        save_meta_policy_rules(rules)
        if args.json:
            print(json.dumps({"status": "trained", "entries": model["total_entries"], "dpo": True, "rules": len(rules)}))
        else:
            print(f"Trained model from {model['total_entries']} entries.")
            print(f"DPO batch optimization applied. Meta-policy rules: {len(rules)}.")
            print(f"Saved to: {MODEL_FILE}")
            # Report on the DPO-adjusted model, not the pre-DPO one.
            print_reliability_table(dpo_model)

    elif args.incremental:
        model = train_incremental(categories)
        if args.json:
            print(json.dumps({"status": "updated", "entries": model["total_entries"]}))
        else:
            print(f"Incremental update complete. Total entries: {model['total_entries']}")

    elif args.reliability:
        model = load_model()
        if args.json:
            results = compute_reliability(model)
            output = {
                "updated": model.get("updated"),
                "total_entries": model.get("total_entries", 0),
                "categories": {
                    cat: {"alpha": a, "beta": b, "reliability": r, "samples": s, "ci_width": ci}
                    for cat, a, b, r, s, ci in results
                },
            }
            print(json.dumps(output, indent=2))
        else:
            print_reliability_table(model)

    elif args.sample:
        model = load_model()
        if args.json:
            samples = sample_posteriors(model)
            print(json.dumps(samples, indent=2))
        else:
            print_samples(model)

    elif args.snapshot:
        model = load_model()
        snapshot_file = save_snapshot(model)
        if args.json:
            print(json.dumps({"snapshot": str(snapshot_file)}))
        else:
            print(f"Snapshot saved: {snapshot_file}")

    elif args.extract_rules:
        rules = extract_meta_policy_rules()
        save_meta_policy_rules(rules)
        if args.json:
            print(json.dumps({"status": "extracted", "rule_count": len(rules), "rules": rules}))
        else:
            print(f"Extracted {len(rules)} meta-policy rules.")
            print(f"Saved to: {META_POLICY_FILE}")
            print_meta_policy_rules()

    elif args.show_rules:
        if args.json:
            rules = load_meta_policy_rules()
            print(json.dumps({"rules": rules}, indent=2))
        else:
            print_meta_policy_rules()

    elif args.dpo_train:
        # Override DPO_BETA via module-level reassignment so train_dpo sees it.
        _override_dpo_beta(args.dpo_beta)
        model = train_dpo(categories)
        if args.json:
            # Echo the freshly written DPO metadata back to the caller.
            dpo_meta = {}
            if DPO_MODEL_FILE.exists():
                with open(DPO_MODEL_FILE) as f:
                    dpo_meta = json.load(f)
            print(json.dumps({"status": "dpo_trained", **dpo_meta}))
        else:
            # Fixed: was an f-string with no placeholders (ruff F541).
            print("DPO batch optimization complete.")
            print(f"Saved to: {DPO_MODEL_FILE}")
            print_dpo_results(model)
            print_reliability_table(model)

    else:
        parser.print_help()
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
# Allow the module to be imported without side effects; the CLI runs only
# when the file is executed directly.
if __name__ == "__main__":
    main()
|