qmdr 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/.claude-plugin/marketplace.json +29 -0
  2. package/.env.example +85 -0
  3. package/.gitattributes +3 -0
  4. package/.github/workflows/release.yml +77 -0
  5. package/AI-SETUP.md +466 -0
  6. package/LICENSE +22 -0
  7. package/README.md +78 -0
  8. package/bun.lock +637 -0
  9. package/docs/README-zh.md +78 -0
  10. package/docs/refactor-checklist.md +54 -0
  11. package/docs/setup-openclaw.md +139 -0
  12. package/example-index.yml +33 -0
  13. package/finetune/BALANCED_DISTRIBUTION.md +157 -0
  14. package/finetune/DATA_IMPROVEMENTS.md +218 -0
  15. package/finetune/Justfile +43 -0
  16. package/finetune/Modelfile +16 -0
  17. package/finetune/README.md +299 -0
  18. package/finetune/SCORING.md +286 -0
  19. package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
  20. package/finetune/configs/grpo.yaml +49 -0
  21. package/finetune/configs/sft.yaml +42 -0
  22. package/finetune/configs/sft_local.yaml +40 -0
  23. package/finetune/convert_gguf.py +221 -0
  24. package/finetune/data/best_glm_prompt.txt +17 -0
  25. package/finetune/data/gepa_generated.prompts.json +32 -0
  26. package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
  27. package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
  28. package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
  29. package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
  30. package/finetune/data/qmd_expansion_locations.jsonl +64 -0
  31. package/finetune/data/qmd_expansion_people.jsonl +46 -0
  32. package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
  33. package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
  34. package/finetune/data/qmd_only_sampled.jsonl +399 -0
  35. package/finetune/dataset/analyze_data.py +369 -0
  36. package/finetune/dataset/clean_data.py +906 -0
  37. package/finetune/dataset/generate_balanced.py +823 -0
  38. package/finetune/dataset/generate_data.py +714 -0
  39. package/finetune/dataset/generate_data_offline.py +206 -0
  40. package/finetune/dataset/generate_diverse.py +441 -0
  41. package/finetune/dataset/generate_ollama.py +326 -0
  42. package/finetune/dataset/prepare_data.py +197 -0
  43. package/finetune/dataset/schema.py +73 -0
  44. package/finetune/dataset/score_data.py +115 -0
  45. package/finetune/dataset/validate_schema.py +104 -0
  46. package/finetune/eval.py +196 -0
  47. package/finetune/evals/queries.txt +56 -0
  48. package/finetune/gepa/__init__.py +1 -0
  49. package/finetune/gepa/best_prompt.txt +31 -0
  50. package/finetune/gepa/best_prompt_glm.txt +1 -0
  51. package/finetune/gepa/dspy_gepa.py +204 -0
  52. package/finetune/gepa/example.py +117 -0
  53. package/finetune/gepa/generate.py +129 -0
  54. package/finetune/gepa/gepa_outputs.jsonl +10 -0
  55. package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
  56. package/finetune/gepa/model.json +19 -0
  57. package/finetune/gepa/optimizer.py +70 -0
  58. package/finetune/gepa/score.py +84 -0
  59. package/finetune/jobs/eval.py +490 -0
  60. package/finetune/jobs/eval_common.py +354 -0
  61. package/finetune/jobs/eval_verbose.py +113 -0
  62. package/finetune/jobs/grpo.py +141 -0
  63. package/finetune/jobs/quantize.py +244 -0
  64. package/finetune/jobs/sft.py +121 -0
  65. package/finetune/pyproject.toml +23 -0
  66. package/finetune/reward.py +610 -0
  67. package/finetune/train.py +611 -0
  68. package/finetune/uv.lock +4070 -0
  69. package/flake.lock +61 -0
  70. package/flake.nix +83 -0
  71. package/migrate-schema.ts +162 -0
  72. package/package.json +56 -0
  73. package/skills/qmdr/SKILL.md +172 -0
  74. package/skills/qmdr/references/mcp-setup.md +88 -0
  75. package/src/app/commands/collection.ts +55 -0
  76. package/src/app/commands/context.ts +82 -0
  77. package/src/app/commands/document.ts +46 -0
  78. package/src/app/commands/maintenance.ts +60 -0
  79. package/src/app/commands/search.ts +45 -0
  80. package/src/app/ports/llm.ts +13 -0
  81. package/src/app/services/llm-service.ts +145 -0
  82. package/src/cli.test.ts +963 -0
  83. package/src/collections.ts +390 -0
  84. package/src/eval.test.ts +412 -0
  85. package/src/formatter.ts +427 -0
  86. package/src/llm.test.ts +559 -0
  87. package/src/llm.ts +1990 -0
  88. package/src/mcp.test.ts +889 -0
  89. package/src/mcp.ts +626 -0
  90. package/src/qmd.ts +3330 -0
  91. package/src/store/collections.ts +7 -0
  92. package/src/store/context.ts +10 -0
  93. package/src/store/db.ts +5 -0
  94. package/src/store/documents.ts +26 -0
  95. package/src/store/maintenance.ts +15 -0
  96. package/src/store/path.ts +13 -0
  97. package/src/store/search.ts +10 -0
  98. package/src/store-paths.test.ts +395 -0
  99. package/src/store.test.ts +2483 -0
  100. package/src/store.ts +2813 -0
  101. package/test/eval-harness.ts +223 -0
  102. package/tsconfig.json +29 -0
@@ -0,0 +1,490 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "peft>=0.7.0",
6
+ # "torch",
7
+ # "huggingface_hub>=0.20.0",
8
+ # "accelerate",
9
+ # ]
10
+ # ///
11
+ """
12
+ Evaluate QMD query expansion models on HuggingFace Jobs.
13
+
14
+ Self-contained script — inlines the reward function and test queries.
15
+
16
+ hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
17
+ hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
18
+ """
19
+
20
+ import argparse
21
+ import csv
22
+ import io
23
+ import json
24
+ import os
25
+ import re
26
+ import sys
27
+ from collections import Counter
28
+
29
+ import torch
30
+ from huggingface_hub import HfApi, login
31
+ from peft import PeftModel
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+
34
# --- Config ---
BASE_MODEL = "Qwen/Qwen3-1.7B"  # backbone checkpoint the adapters were trained from
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"  # PEFT adapter (supervised fine-tune)
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"  # PEFT adapter (GRPO, loaded on top of merged SFT)
38
+
39
+ # --- Test queries (inlined from evals/queries.txt) ---
40
# Evaluation queries covering the query types the expansion model must
# handle; grouped by category (inlined from evals/queries.txt).
QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous
    "auth",
    "config",
    "setup",
    "api",
    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
80
+
81
+ # =============================================================================
82
+ # Reward function (inlined from reward.py)
83
+ # =============================================================================
84
+
85
# Filler words ignored when counting repeated words in hyde passages
# (see word_repetition_penalty).
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Question/function words excluded when extracting a query's key terms,
# and when deciding whether a capitalized word counts as a named entity.
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Boilerplate phrases: a lex expansion that is (almost) nothing but one of
# these is flagged as generic and penalized (see lex_is_generic).
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Chat-template markers; any occurrence in model output is template leakage
# and zeroes the score (see score_expansion_detailed).
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
106
+
107
+
108
def parse_expansion(text):
    """Split a model expansion into its lex/vec/hyde lines.

    Returns a dict with keys "lex", "vec", "hyde", "invalid"; each value is
    a list of payloads (prefix stripped) in original order.  Lines carrying
    none of the three known prefixes land in "invalid" verbatim.
    """
    prefixes = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for prefix, key in prefixes:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
123
+
124
+
125
def clean_model_output(text):
    """Strip chat-template end tokens and <think> blocks from raw output.

    Returns (cleaned_text, used_thinking); used_thinking is True when a
    complete <think>...</think> block was present (and removed).
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    has_thinking = '<think>' in cleaned and '</think>' in cleaned
    if has_thinking:
        cleaned = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL)
        cleaned = cleaned.strip()
    return cleaned, has_thinking
131
+
132
+
133
def extract_named_entities(query):
    """Heuristically pull named-entity tokens (lowercased) out of a query.

    A token counts as an entity when it is ALL-CAPS, capitalized
    mid-sentence, contains code-ish punctuation (.+-#@), is mixed-case with
    a leading capital, or directly follows another entity token (so
    multi-word names like "TDS motorsports" are captured whole).
    NOTE(review): a capitalized FIRST word alone is deliberately not a
    signal (the `i > 0` guard) — sentence-initial caps are ambiguous.
    """
    entities = set()
    words = query.split()
    prev_was_entity = False  # lets continuation words attach to a preceding name
    for i, word in enumerate(words):
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            # Pure punctuation breaks a multi-word entity chain.
            prev_was_entity = False
            continue
        is_entity = False
        # Acronyms: "AWS", "TDS"
        if clean.isupper() and len(clean) >= 2:
            entities.add(clean.lower()); is_entity = True
        # Mid-sentence capitalization: "Shopify", "Docker"
        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower()); is_entity = True
        # Technical tokens with punctuation: "node.js", "c++", "c#", "@scope"
        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
            entities.add(clean.lower()); is_entity = True
        # Mixed-case with leading capital: "PyTorch"
        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
            entities.add(clean.lower()); is_entity = True
        # Continuation of a multi-word entity: "TDS motorsports" -> "motorsports"
        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower()); is_entity = True
        prev_was_entity = is_entity
    return entities
155
+
156
+
157
def get_key_terms(query):
    """Return the content-bearing words of *query*: lowercased tokens minus stopwords."""
    tokens = query.lower().split()
    return {t for t in tokens if t not in KEY_TERM_STOPWORDS}
159
+
160
+
161
def lex_preserves_key_terms(lex_line, query):
    """True when *lex_line* keeps at least one key term of *query*.

    Queries with no key terms (all stopwords) trivially pass.
    """
    key_terms = get_key_terms(query)
    if not key_terms:
        return True
    line_words = set(lex_line.lower().split())
    return any(term in line_words for term in key_terms)
164
+
165
+
166
def lex_preserves_entities(line, entities):
    """True when *line* mentions at least one entity (substring match), or
    when there are no entities to preserve."""
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
169
+
170
+
171
def lex_is_generic(lex_line):
    """Detect lex expansions that are pure boilerplate ("search for ...").

    A line is generic when, after removing the words of a matched generic
    phrase (each word stripped once, in order), fewer than 3 characters
    remain — i.e. the line carries no real content beyond the filler.
    NOTE(review): the startswith() arm matches only the phrase's FIRST word
    ("find", "what", ...), so any line starting with such a word is also
    run through that phrase's word-removal check.
    """
    lower = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        if phrase in lower or lower.startswith(phrase.split()[0]):
            remaining = lower
            # Strip each phrase word once, then see what content is left.
            for word in phrase.split():
                remaining = remaining.replace(word, '', 1).strip()
            if len(remaining) < 3:
                return True
    return False
181
+
182
+
183
def word_set_distance(a, b):
    """Size of the symmetric difference between the word sets of *a* and *b*."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len((words_a - words_b) | (words_b - words_a))
185
+
186
+
187
def is_diverse(a, b, min_distance=2):
    """True when two expansions differ materially: neither contains the
    other (covers equality too) and their word sets differ by at least
    *min_distance* words."""
    left = a.lower().strip()
    right = b.lower().strip()
    if left in right or right in left:
        return False
    return word_set_distance(left, right) >= min_distance
191
+
192
+
193
def echoes_query(expansion, query):
    """True when *expansion* merely parrots *query*: identical, or contains
    it with fewer than 10 extra characters."""
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if exp == q:
        return True
    return q in exp and len(exp) - len(q) < 10
196
+
197
+
198
def word_repetition_penalty(text):
    """Penalty points for words repeated 3+ times: 2 points per occurrence
    past the second, ignoring stopwords and words of <= 2 characters."""
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in counts.items():
        if count < 3 or len(word) <= 2 or word in STOPWORDS:
            continue
        penalty += (count - 2) * 2
    return penalty
202
+
203
+
204
def score_expansion_detailed(query, expansion):
    """Score one model expansion for *query* on a rubric of five components.

    Components: format (prefixes present), diversity (distinct, non-echoing
    expansions), hyde (optional passage quality), quality (lex/vec shape),
    entity (named entities preserved), plus a bonus for NOT using <think>.
    Hard failures (chat-template leakage, unparseable lines) zero everything.

    Returns a dict of per-component scores, "total", "max_possible" (140
    with a hyde line, 120 without), "percentage", "rating", "deductions",
    and "entities_detected".
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    # Hard-failure result: every component zeroed, reason recorded.
    def _fail(reason):
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed", "deductions": [reason],
        }

    # Any raw chat-template marker in the output is an automatic fail.
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")
    # Every non-empty line must carry one of the three known prefixes.
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    parsed = parse_expansion(text)

    # --- Format: 10 base + 10 per required expansion type (max 30) ---
    format_score = 10
    if parsed["lex"]: format_score += 10
    else: deductions.append("missing lex:")
    if parsed["vec"]: format_score += 10
    else: deductions.append("missing vec:")

    # --- Diversity (max 30): both types, >=2 expansions, pairwise
    # distinctness within lex and vec, and no echoing of the query ---
    diversity_score = 0
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2: diversity_score += 10
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2): lex_div -= 2
    diversity_score += max(0, lex_div)
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            # vec lines are natural-language, so require more word distance.
            if not is_diverse(a, b, 3): vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for exp in parsed["lex"] + parsed["vec"]:
        if echoes_query(exp, query): echo -= 3
    diversity_score += max(0, echo)

    # --- Hyde (max 20, only when present): length in a useful band,
    # single line, low word repetition ---
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200: hyde_score += 5
        elif hyde_len < 50: hyde_score += 2
        if "\n" not in hyde_text: hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (5 base, max 20): lex shorter than vec on average, vec
    # lines natural-language, lex lines keep the query's key terms ---
    quality_score = 5
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec: quality_score += 5
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]): quality_score += 5
        elif with_terms > 0: quality_score += 2

    # --- Entity preservation (max 20): dropping detected entities is
    # punished hard (-30), generic filler lex lines -15 each ---
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]): entity_score += 15
        elif with_entities > 0: entity_score += 5
        else: entity_score -= 30
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count: entity_score -= generic_count * 15
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0: entity_score += 5
    elif not entities:
        # No entities in the query: free pass for this component.
        entity_score = 10

    # Not using <think> is rewarded: expansions should be cheap and direct.
    think_bonus = 0 if used_thinking else 20
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80: rating = "Excellent"
    elif percentage >= 60: rating = "Good"
    elif percentage >= 40: rating = "Acceptable"
    elif percentage >= 20: rating = "Poor"
    else: rating = "Failed"

    return {
        "format": format_score, "diversity": diversity_score, "hyde": hyde_score,
        "quality": quality_score, "entity": max(0, entity_score),
        "think_bonus": think_bonus, "total": max(0, total),
        "max_possible": max_possible, "percentage": round(percentage, 1),
        "rating": rating, "deductions": deductions,
        "entities_detected": list(entities) if entities else [],
    }
306
+
307
+
308
+ # =============================================================================
309
+ # Model loading and generation
310
+ # =============================================================================
311
+
312
def load_model(base, sft=None, grpo=None):
    """Load the base causal LM, optionally stacking PEFT adapters.

    Order matters: the SFT adapter is merged into the base weights first
    (merge_and_unload), then the GRPO adapter is loaded on top of the
    merged model without merging — mirroring the reload sequence used in
    main() for GRPO evaluation.

    Returns (model, tokenizer) with the model in eval mode.
    """
    print(f"Loading tokenizer from {base}...")
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS so generate()
        # has a valid pad_token_id.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base model {base}...")
    model = AutoModelForCausalLM.from_pretrained(
        base, torch_dtype=torch.bfloat16, device_map="auto",
    )

    if sft:
        # Merge SFT weights into the base so a further adapter can stack on top.
        print(f"Loading and merging SFT adapter {sft}...")
        model = PeftModel.from_pretrained(model, sft)
        model = model.merge_and_unload()

    if grpo:
        # Left un-merged: adapter weights are applied at forward time.
        print(f"Loading GRPO adapter {grpo}...")
        model = PeftModel.from_pretrained(model, grpo)

    model.eval()
    return model, tokenizer
334
+
335
+
336
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Sample one query expansion from the model for *query*.

    Builds a single-turn chat prompt (the "/no_think" prefix is presumably
    the Qwen3 switch to suppress thinking mode — scoring rewards its
    absence), samples at temperature 0.7, extracts the assistant turn from
    the decoded text, and strips any residual <think> block.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            temperature=0.7, do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # skip_special_tokens drops the <|im_start|> markers, so the assistant
    # turn is located by the bare "assistant" role line instead.
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "\nassistant\n" in full_output:
        expansion = full_output.split("\nassistant\n")[-1].strip()
    elif "assistant\n" in full_output:
        expansion = full_output.split("assistant\n")[-1].strip()
    else:
        # Fallback: assume the prompt was echoed verbatim at the start.
        expansion = full_output[len(prompt):].strip()

    # Defensive: remove a thinking block if the model emitted one anyway.
    if "<think>" in expansion:
        expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
    return expansion
360
+
361
+
362
+ # =============================================================================
363
+ # Main
364
+ # =============================================================================
365
+
366
def results_to_csv(results, label):
    """Render eval results as a CSV string: header row plus one row per query."""
    header = [
        "model", "query", "expansion", "score_pct", "rating",
        "format", "diversity", "hyde", "quality", "entity", "think_bonus",
        "total", "max_possible", "deductions",
    ]
    rows = [header]
    for entry in results:
        scores = entry["scores"]
        rows.append([
            label, entry["query"], entry["expansion"],
            scores["percentage"], scores["rating"],
            scores["format"], scores["diversity"], scores["hyde"],
            scores["quality"], scores["entity"], scores["think_bonus"],
            scores["total"], scores["max_possible"],
            "; ".join(scores.get("deductions", [])),
        ])
    buf = io.StringIO()
    csv.writer(buf).writerows(rows)
    return buf.getvalue()
384
+
385
+
386
def upload_csv(results, label, repo_id, api):
    """Upload eval results CSV to HuggingFace Hub.

    The file is named ``eval_<tag>.csv`` where *tag* is the last path
    segment of *label*, lowercased with spaces replaced by underscores.
    """
    csv_data = results_to_csv(results, label)
    tag = label.split("/")[-1].replace(" ", "_").lower()
    filename = f"eval_{tag}.csv"
    # Fix: these log lines printed the literal text "(unknown)" instead of
    # interpolating the computed filename.
    print(f" Uploading {filename} to {repo_id}...")
    api.upload_file(
        path_or_fileobj=csv_data.encode("utf-8"),
        path_in_repo=filename,
        repo_id=repo_id,
        repo_type="model",
    )
    print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
399
+
400
+
401
def evaluate_model(model, tokenizer, label):
    """Run every query in QUERIES through the model and score the outputs.

    Prints a per-query scoreboard, a rating histogram, and the five
    worst-scoring queries.  Returns (results, average_percentage) where
    results is a list of {"query", "expansion", "scores"} dicts in QUERIES
    order.
    """
    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        scores = score_expansion_detailed(query, expansion)
        results.append({"query": query, "expansion": expansion, "scores": scores})

        # Marker legend: '+' >= 80%, '-' < 60%, '~' in between.
        marker = "+" if scores["percentage"] >= 80 else "-" if scores["percentage"] < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")

    avg = sum(r["scores"]["percentage"] for r in results) / len(results)
    ratings = Counter(r["scores"]["rating"] for r in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    print(f" Ratings:")
    for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        count = ratings.get(rating, 0)
        if count > 0:
            # Bar of one block per query with this rating.
            print(f" {rating:10s}: {count:2d} {'█' * count}")

    # Show worst queries
    worst = sorted(results, key=lambda r: r["scores"]["percentage"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
        if r["scores"]["deductions"]:
            print(f" {', '.join(r['scores']['deductions'][:3])}")

    return results, avg
435
+
436
+
437
def main():
    """CLI entry point: evaluate SFT (and, unless --sft-only, GRPO), upload
    per-model CSVs plus a combined comparison CSV, and print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
    parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
                        help="HF repo to upload CSV results")
    args = parser.parse_args()

    # HF_TOKEN is injected as a job secret (see module docstring); login is
    # skipped when absent.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    api = HfApi()
    api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)

    # Evaluate SFT
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
    sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
    upload_csv(sft_results, "sft", args.upload_repo, api)

    if not args.sft_only:
        # For GRPO: reload base, merge SFT, then load GRPO adapter
        del model
        torch.cuda.empty_cache()  # release GPU memory before the second load
        model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
        grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
        upload_csv(grpo_results, "grpo", args.upload_repo, api)

        # Combined CSV: SFT rows + GRPO rows with the duplicate header dropped.
        combined = results_to_csv(sft_results, "sft") + results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
        api.upload_file(
            path_or_fileobj=combined.encode("utf-8"),
            path_in_repo="eval_comparison.csv",
            repo_id=args.upload_repo,
            repo_type="model",
        )
        print(f" Uploaded: eval_comparison.csv")

        # Comparison
        print(f"\n{'='*70}")
        print(f" COMPARISON")
        print(f"{'='*70}")
        print(f" SFT average: {sft_avg:.1f}%")
        print(f" GRPO average: {grpo_avg:.1f}%")
        print(f" Delta: {grpo_avg - sft_avg:+.1f}%")

        # Per-query win/loss count; zip pairs by QUERIES order.
        improved = sum(1 for s, g in zip(sft_results, grpo_results)
                       if g["scores"]["percentage"] > s["scores"]["percentage"])
        regressed = sum(1 for s, g in zip(sft_results, grpo_results)
                        if g["scores"]["percentage"] < s["scores"]["percentage"])
        print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")
487
+
488
+
489
# Script entry point (run directly, or remotely via `hf jobs uv run`).
if __name__ == "__main__":
    main()