qmdr 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. package/.claude-plugin/marketplace.json +29 -0
  2. package/.env.example +85 -0
  3. package/.gitattributes +3 -0
  4. package/.github/workflows/release.yml +77 -0
  5. package/AI-SETUP.md +466 -0
  6. package/LICENSE +22 -0
  7. package/README.md +78 -0
  8. package/bun.lock +637 -0
  9. package/docs/README-zh.md +78 -0
  10. package/docs/refactor-checklist.md +54 -0
  11. package/docs/setup-openclaw.md +139 -0
  12. package/example-index.yml +33 -0
  13. package/finetune/BALANCED_DISTRIBUTION.md +157 -0
  14. package/finetune/DATA_IMPROVEMENTS.md +218 -0
  15. package/finetune/Justfile +43 -0
  16. package/finetune/Modelfile +16 -0
  17. package/finetune/README.md +299 -0
  18. package/finetune/SCORING.md +286 -0
  19. package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
  20. package/finetune/configs/grpo.yaml +49 -0
  21. package/finetune/configs/sft.yaml +42 -0
  22. package/finetune/configs/sft_local.yaml +40 -0
  23. package/finetune/convert_gguf.py +221 -0
  24. package/finetune/data/best_glm_prompt.txt +17 -0
  25. package/finetune/data/gepa_generated.prompts.json +32 -0
  26. package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
  27. package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
  28. package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
  29. package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
  30. package/finetune/data/qmd_expansion_locations.jsonl +64 -0
  31. package/finetune/data/qmd_expansion_people.jsonl +46 -0
  32. package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
  33. package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
  34. package/finetune/data/qmd_only_sampled.jsonl +399 -0
  35. package/finetune/dataset/analyze_data.py +369 -0
  36. package/finetune/dataset/clean_data.py +906 -0
  37. package/finetune/dataset/generate_balanced.py +823 -0
  38. package/finetune/dataset/generate_data.py +714 -0
  39. package/finetune/dataset/generate_data_offline.py +206 -0
  40. package/finetune/dataset/generate_diverse.py +441 -0
  41. package/finetune/dataset/generate_ollama.py +326 -0
  42. package/finetune/dataset/prepare_data.py +197 -0
  43. package/finetune/dataset/schema.py +73 -0
  44. package/finetune/dataset/score_data.py +115 -0
  45. package/finetune/dataset/validate_schema.py +104 -0
  46. package/finetune/eval.py +196 -0
  47. package/finetune/evals/queries.txt +56 -0
  48. package/finetune/gepa/__init__.py +1 -0
  49. package/finetune/gepa/best_prompt.txt +31 -0
  50. package/finetune/gepa/best_prompt_glm.txt +1 -0
  51. package/finetune/gepa/dspy_gepa.py +204 -0
  52. package/finetune/gepa/example.py +117 -0
  53. package/finetune/gepa/generate.py +129 -0
  54. package/finetune/gepa/gepa_outputs.jsonl +10 -0
  55. package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
  56. package/finetune/gepa/model.json +19 -0
  57. package/finetune/gepa/optimizer.py +70 -0
  58. package/finetune/gepa/score.py +84 -0
  59. package/finetune/jobs/eval.py +490 -0
  60. package/finetune/jobs/eval_common.py +354 -0
  61. package/finetune/jobs/eval_verbose.py +113 -0
  62. package/finetune/jobs/grpo.py +141 -0
  63. package/finetune/jobs/quantize.py +244 -0
  64. package/finetune/jobs/sft.py +121 -0
  65. package/finetune/pyproject.toml +23 -0
  66. package/finetune/reward.py +610 -0
  67. package/finetune/train.py +611 -0
  68. package/finetune/uv.lock +4070 -0
  69. package/flake.lock +61 -0
  70. package/flake.nix +83 -0
  71. package/migrate-schema.ts +162 -0
  72. package/package.json +56 -0
  73. package/skills/qmdr/SKILL.md +172 -0
  74. package/skills/qmdr/references/mcp-setup.md +88 -0
  75. package/src/app/commands/collection.ts +55 -0
  76. package/src/app/commands/context.ts +82 -0
  77. package/src/app/commands/document.ts +46 -0
  78. package/src/app/commands/maintenance.ts +60 -0
  79. package/src/app/commands/search.ts +45 -0
  80. package/src/app/ports/llm.ts +13 -0
  81. package/src/app/services/llm-service.ts +145 -0
  82. package/src/cli.test.ts +963 -0
  83. package/src/collections.ts +390 -0
  84. package/src/eval.test.ts +412 -0
  85. package/src/formatter.ts +427 -0
  86. package/src/llm.test.ts +559 -0
  87. package/src/llm.ts +1990 -0
  88. package/src/mcp.test.ts +889 -0
  89. package/src/mcp.ts +626 -0
  90. package/src/qmd.ts +3330 -0
  91. package/src/store/collections.ts +7 -0
  92. package/src/store/context.ts +10 -0
  93. package/src/store/db.ts +5 -0
  94. package/src/store/documents.ts +26 -0
  95. package/src/store/maintenance.ts +15 -0
  96. package/src/store/path.ts +13 -0
  97. package/src/store/search.ts +10 -0
  98. package/src/store-paths.test.ts +395 -0
  99. package/src/store.test.ts +2483 -0
  100. package/src/store.ts +2813 -0
  101. package/test/eval-harness.ts +223 -0
  102. package/tsconfig.json +29 -0
@@ -0,0 +1,354 @@
1
+ """
2
+ Common evaluation and reward scoring for QMD query expansion models.
3
+
4
+ Shared by sft.py and grpo.py for post-training evaluation.
5
+ """
6
+
7
+ import csv
8
+ import io
9
+ import re
10
+ from collections import Counter
11
+
12
+ import torch
13
+ from huggingface_hub import HfApi
14
+
15
# =============================================================================
# Reward function (single source of truth)
# =============================================================================

# Words ignored by the word-repetition penalty.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Filler words excluded when extracting key terms / named entities from a query.
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Boilerplate search phrases that carry no retrievable content on their own.
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Any of these leaking into output means the chat template was not stripped.
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
40
+
41
+
42
def parse_expansion(text):
    """Split a model expansion into lex/vec/hyde buckets plus unrecognized lines.

    Returns a dict with keys "lex", "vec", "hyde", "invalid"; each value is a
    list of the stripped payloads (prefix removed) in input order.
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for prefix in ("lex:", "vec:", "hyde:"):
            if stripped.startswith(prefix):
                buckets[prefix[:-1]].append(stripped[len(prefix):].strip())
                break
        else:
            # No recognized prefix — record the whole line for diagnostics.
            buckets["invalid"].append(stripped)
    return buckets
57
+
58
+
59
def clean_model_output(text):
    """Strip the chat end-token and any complete <think>...</think> span.

    Returns (cleaned_text, used_thinking); used_thinking reports whether a
    closed think block was present in the raw output.
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    had_thinking = '<think>' in cleaned and '</think>' in cleaned
    if had_thinking:
        cleaned = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL).strip()
    return cleaned, had_thinking
65
+
66
+
67
def extract_named_entities(query):
    """Heuristically collect entity-like tokens (lowercased) from a query.

    Matches, in priority order: ALL-CAPS acronyms, capitalized words past
    position 0, tokens containing code-ish punctuation (.+-#@), CamelCase
    words, and continuations of an immediately preceding entity.
    NOTE(review): a plain capitalized first word is never matched by the
    capitalization rule (guards against sentence-case false positives).
    """
    found = set()
    previous_hit = False
    for idx, token in enumerate(query.split()):
        stripped = token.strip('.,!?:;()[]"\'')
        if not stripped:
            previous_hit = False
            continue
        lowered = stripped.lower()
        hit = False
        if stripped.isupper() and len(stripped) >= 2:
            hit = True  # acronym / ALL-CAPS token
        elif idx > 0 and stripped[0].isupper() and lowered not in KEY_TERM_STOPWORDS:
            hit = True  # mid-query Capitalized word
        elif len(stripped) >= 2 and any(ch in stripped for ch in '.+-#@'):
            hit = True  # code-like token (c++, node.js, #tag, @handle)
        elif len(stripped) > 1 and any(ch.isupper() for ch in stripped[1:]) and stripped[0].isupper():
            hit = True  # CamelCase word
        elif previous_hit and lowered not in KEY_TERM_STOPWORDS:
            hit = True  # continuation of a multi-word entity phrase
        if hit:
            found.add(lowered)
        previous_hit = hit
    return found
89
+
90
+
91
def get_key_terms(query):
    """Return the query's lowercased tokens minus question/filler words."""
    tokens = set(query.lower().split())
    return tokens - KEY_TERM_STOPWORDS
93
+
94
+
95
def lex_preserves_key_terms(lex_line, query):
    """True when the lex line shares at least one key term with the query.

    Vacuously true when the query yields no key terms at all.
    """
    wanted = get_key_terms(query)
    if not wanted:
        return True
    return bool(wanted & set(lex_line.lower().split()))
98
+
99
+
100
def lex_preserves_entities(line, entities):
    """True when the line mentions any of the entities (or none were given)."""
    if not entities:
        return True
    haystack = line.lower()
    return any(entity in haystack for entity in entities)
104
+
105
+
106
def lex_is_generic(lex_line):
    """True when a lex line is a boilerplate search phrase with essentially no
    content left after the generic words are removed."""
    lowered = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        words = phrase.split()
        # Trigger on the full phrase, or merely starting with its first word.
        if phrase not in lowered and not lowered.startswith(words[0]):
            continue
        leftover = lowered
        for word in words:
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
116
+
117
+
118
def word_set_distance(a, b):
    """Count words that appear in exactly one of the two strings (case-folded)."""
    left = set(a.lower().split())
    right = set(b.lower().split())
    return len(left.symmetric_difference(right))
120
+
121
+
122
def is_diverse(a, b, min_distance=2):
    """True when two expansions differ enough to be worth keeping both:
    neither contains the other, and their word sets differ by at least
    min_distance words."""
    left, right = a.lower().strip(), b.lower().strip()
    if left == right or left in right or right in left:
        return False
    differing = set(left.split()) ^ set(right.split())
    return len(differing) >= min_distance
127
+
128
+
129
def echoes_query(expansion, query):
    """True when the expansion merely repeats the query, or contains it while
    adding fewer than 10 extra characters."""
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if exp == q:
        return True
    return q in exp and len(exp) < len(q) + 10
132
+
133
+
134
def word_repetition_penalty(text):
    """Penalty of 2 points per extra occurrence for content words used 3+ times
    (stopwords and very short words exempt)."""
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in counts.items():
        if count >= 3 and len(word) > 2 and word not in STOPWORDS:
            penalty += (count - 2) * 2
    return penalty
138
+
139
+
140
def score_expansion(query, expansion):
    """Score a model expansion as a float in [0.0, 1.0] for RL reward.

    Components: output format, lex/vec diversity, hyde quality, general
    quality, and named-entity preservation, plus a bonus for skipping the
    <think> block. Leaked chat-template tokens or any line without a known
    prefix is an immediate 0.0.
    """
    text, used_thinking = clean_model_output(expansion.strip())

    # Hard failures first.
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return 0.0
    for raw in text.split("\n"):
        raw = raw.strip()
        if raw and not raw.startswith(("lex:", "vec:", "hyde:")):
            return 0.0

    parsed = parse_expansion(text)

    # Format: base points plus presence of each expansion type.
    format_score = 10
    if parsed["lex"]:
        format_score += 10
    if parsed["vec"]:
        format_score += 10

    # Diversity: both types present, multiple lines, pairwise distinctness,
    # and not merely echoing the query back.
    diversity_score = 0
    if parsed["lex"] and parsed["vec"]:
        diversity_score += 10
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5
    lex_div = 5
    for i, first in enumerate(parsed["lex"]):
        for second in parsed["lex"][i + 1:]:
            if not is_diverse(first, second, 2):
                lex_div -= 2
    diversity_score += max(0, lex_div)
    vec_div = 5
    for i, first in enumerate(parsed["vec"]):
        for second in parsed["vec"][i + 1:]:
            if not is_diverse(first, second, 3):
                vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for candidate in parsed["lex"] + parsed["vec"]:
        if echoes_query(candidate, query):
            echo -= 3
    diversity_score += max(0, echo)

    # Hyde: present, reasonable length, single line, low word repetition.
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        if 50 <= len(hyde_text) <= 200:
            hyde_score += 5
        elif len(hyde_text) < 50:
            hyde_score += 2
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # Quality: lex lines shorter than vec on average, vec phrased naturally,
    # and lex keeping the query's key terms.
    quality_score = 5
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2

    # Entities: heavily penalize dropping named entities or generic lex lines.
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        # Entity-free queries get full marks on this axis.
        entity_score = 10

    # Reward the /no_think path; normalize by the applicable maximum.
    think_bonus = 0 if used_thinking else 20
    total = (format_score + diversity_score + hyde_score
             + quality_score + entity_score + think_bonus)
    max_possible = 140 if parsed["hyde"] else 120
    return max(0.0, min(1.0, total / max_possible))
216
+
217
+
218
def extract_query_from_prompt(prompt):
    """Pull the raw search query back out of a formatted chat prompt.

    Falls back to the stripped prompt when the instruction marker is absent.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    query = prompt.split(marker)[-1].strip()
    if "<|im_end|>" in query:
        query = query.split("<|im_end|>")[0].strip()
    return query
226
+
227
+
228
class QMDRewardFunction:
    """Callable reward wrapper for TRL's GRPOTrainer.

    Scores each completion against the query recovered from its prompt;
    completions without a matching prompt are scored against "".
    """
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions, prompts=None, **kwargs):
        scores = []
        for idx, completion in enumerate(completions):
            query = ""
            if prompts and idx < len(prompts):
                query = extract_query_from_prompt(prompts[idx])
            scores.append(score_expansion(query, completion))
        return scores
240
+
241
+
242
+ # =============================================================================
243
+ # Evaluation
244
+ # =============================================================================
245
+
246
# Fixed eval set shared by sft.py and grpo.py post-training evaluation.
EVAL_QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous
    "auth", "config", "setup", "api",
    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
283
+
284
+
285
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Sample one query expansion from the model and strip chat scaffolding.

    Uses the /no_think instruction prefix; sampling at temperature 0.7.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Prefer the assistant turn if the template survived decoding; otherwise
    # fall back to slicing off the prompt prefix.
    for marker in ("\nassistant\n", "assistant\n"):
        if marker in decoded:
            return decoded.split(marker)[-1].strip()
    return decoded[len(prompt):].strip()
303
+
304
+
305
def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
    """Evaluate model on EVAL_QUERIES, print results, upload CSV.

    Args:
        model, tokenizer: the (merged) causal LM and its tokenizer.
        label: short run name; also embedded in the uploaded CSV filename.
        upload_repo: HF Hub repo that receives eval_<label>.csv.
    """
    api = HfApi()
    api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)

    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(EVAL_QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        score = score_expansion(query, expansion)
        pct = round(score * 100, 1)
        rating = ("Excellent" if pct >= 80 else "Good" if pct >= 60
                  else "Acceptable" if pct >= 40 else "Poor" if pct >= 20 else "Failed")
        marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(EVAL_QUERIES)} {pct:5.1f}% {rating:10s} {query}")
        results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})

    avg = sum(r["score"] for r in results) / len(results)
    ratings = Counter(r["rating"] for r in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    for r in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        c = ratings.get(r, 0)
        if c:
            print(f" {r:10s}: {c:2d} {'█' * c}")

    worst = sorted(results, key=lambda r: r["score"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['score']:5.1f}% {r['query']}")

    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
    for r in results:
        writer.writerow([label, r["query"], r["expansion"], r["score"], r["rating"]])

    filename = f"eval_{label}.csv"
    # BUG FIX: these two messages printed the literal text "(unknown)" instead
    # of the computed upload filename; interpolate `filename` as intended.
    print(f"\n Uploading {filename} to {upload_repo}...")
    api.upload_file(
        path_or_fileobj=buf.getvalue().encode("utf-8"),
        path_in_repo=filename,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f" Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")
@@ -0,0 +1,113 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "peft>=0.7.0",
6
+ # "torch",
7
+ # "huggingface_hub>=0.20.0",
8
+ # "accelerate",
9
+ # ]
10
+ # ///
11
+ """
12
+ Verbose eval: prints the actual expansions for every query.
13
+
14
+ hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval_verbose.py
15
+ """
16
+
17
+ import os
18
+ import re
19
+ import sys
20
+ from collections import Counter
21
+
22
+ import torch
23
+ from huggingface_hub import login
24
+ from peft import PeftModel
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+
27
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"

# Eval prompts — mirrors EVAL_QUERIES in eval_common.py.
QUERIES = [
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    "auth",
    "config",
    "setup",
    "api",
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
63
+
64
+
65
def load_model(base, sft=None, grpo=None):
    """Load *base* in bf16, optionally merging an SFT adapter and stacking a
    GRPO adapter on top.

    Returns (model, tokenizer) with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        base, torch_dtype=torch.bfloat16, device_map="auto")
    if sft:
        # Merge SFT weights so a GRPO adapter applies on top of them.
        model = PeftModel.from_pretrained(model, sft)
        model = model.merge_and_unload()
    if grpo:
        model = PeftModel.from_pretrained(model, grpo)
    model.eval()
    return model, tokenizer
77
+
78
+
79
def generate(model, tokenizer, query):
    """Sample one expansion for *query*, stripped of chat scaffolding and any
    <think> block."""
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=200, temperature=0.7, do_sample=True,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Keep only the assistant turn when the template survived decoding.
    for marker in ("\nassistant\n", "assistant\n"):
        if marker in text:
            text = text.split(marker)[-1].strip()
            break
    if "<think>" in text:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    return text
94
+
95
+
96
def main():
    """Load the GRPO model and print the raw expansion for every eval query."""
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # Progress to stderr so stdout stays clean for the expansions themselves.
    print("Loading GRPO model...", file=sys.stderr)
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)

    for idx, query in enumerate(QUERIES, 1):
        expansion = generate(model, tokenizer, query)
        print(f"\n{'='*60}")
        print(f"[{idx}/{len(QUERIES)}] {query}")
        print(f"{'─'*60}")
        print(expansion)


if __name__ == "__main__":
    main()
@@ -0,0 +1,141 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.45.0",
7
+ # "accelerate>=0.24.0",
8
+ # "huggingface_hub>=0.20.0",
9
+ # "datasets",
10
+ # "bitsandbytes",
11
+ # "torch",
12
+ # ]
13
+ # ///
14
+ """
15
+ GRPO training for QMD query expansion (Qwen3-1.7B).
16
+
17
+ Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
18
+ hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
19
+ """
20
+
21
+ import os
22
+ import sys
23
+
24
+ import torch
25
+ from datasets import load_dataset
26
+ from huggingface_hub import login
27
+ from peft import LoraConfig, PeftModel, get_peft_model
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+ from trl import GRPOTrainer, GRPOConfig
30
+
31
# Download eval_common.py if running as a standalone script (e.g. HF Jobs)
_eval_common_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "eval_common.py")
if not os.path.exists(_eval_common_path):
    import urllib.request
    _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
    _opener = urllib.request.build_opener()
    _token = os.environ.get("HF_TOKEN", "")
    if _token:
        # The hosting dataset repo may be private; authenticate when possible.
        _opener.addheaders = [("Authorization", f"Bearer {_token}")]
    with open(_eval_common_path, "wb") as _f:
        _f.write(_opener.open(_url).read())
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from eval_common import QMDRewardFunction, run_eval
44
+
45
# --- Config (inlined from configs/grpo.yaml) ---
BASE_MODEL = "Qwen/Qwen3-1.7B"                       # frozen base weights
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"     # adapter merged before GRPO
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo" # hub destination
DATASET = "tobil/qmd-query-expansion-train"          # prompt source
50
+
51
+
52
def main():
    """Run GRPO on top of the merged SFT checkpoint, push, then evaluate."""
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Dataset: reuse the SFT prompts, rendered through the chat template.
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def to_prompt(example):
        # Keep only the user turn; GRPO samples its own completions.
        user_content = example["messages"][0]["content"]
        rendered = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_content}],
            tokenize=False, add_generation_prompt=True)
        return {"prompt": rendered}

    dataset = dataset.map(to_prompt, remove_columns=dataset.column_names)
    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    # Starting policy = base model with the SFT adapter merged in.
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
    )
    print(f"Merging SFT adapter {SFT_MODEL}...")
    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Train a fresh, deliberately small LoRA (rank 4, q/v projections only).
    grpo_lora = LoraConfig(
        r=4, lora_alpha=8, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, grpo_lora)
    model.print_trainable_parameters()

    config = GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,
        num_generations=4,
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        report_to="none",
    )

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # Score the trained policy on the shared eval set and upload the CSV.
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")


if __name__ == "__main__":
    main()