qmdr 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/marketplace.json +29 -0
- package/.env.example +85 -0
- package/.gitattributes +3 -0
- package/.github/workflows/release.yml +77 -0
- package/AI-SETUP.md +466 -0
- package/LICENSE +22 -0
- package/README.md +78 -0
- package/bun.lock +637 -0
- package/docs/README-zh.md +78 -0
- package/docs/refactor-checklist.md +54 -0
- package/docs/setup-openclaw.md +139 -0
- package/example-index.yml +33 -0
- package/finetune/BALANCED_DISTRIBUTION.md +157 -0
- package/finetune/DATA_IMPROVEMENTS.md +218 -0
- package/finetune/Justfile +43 -0
- package/finetune/Modelfile +16 -0
- package/finetune/README.md +299 -0
- package/finetune/SCORING.md +286 -0
- package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
- package/finetune/configs/grpo.yaml +49 -0
- package/finetune/configs/sft.yaml +42 -0
- package/finetune/configs/sft_local.yaml +40 -0
- package/finetune/convert_gguf.py +221 -0
- package/finetune/data/best_glm_prompt.txt +17 -0
- package/finetune/data/gepa_generated.prompts.json +32 -0
- package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
- package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
- package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
- package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
- package/finetune/data/qmd_expansion_locations.jsonl +64 -0
- package/finetune/data/qmd_expansion_people.jsonl +46 -0
- package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
- package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
- package/finetune/data/qmd_only_sampled.jsonl +399 -0
- package/finetune/dataset/analyze_data.py +369 -0
- package/finetune/dataset/clean_data.py +906 -0
- package/finetune/dataset/generate_balanced.py +823 -0
- package/finetune/dataset/generate_data.py +714 -0
- package/finetune/dataset/generate_data_offline.py +206 -0
- package/finetune/dataset/generate_diverse.py +441 -0
- package/finetune/dataset/generate_ollama.py +326 -0
- package/finetune/dataset/prepare_data.py +197 -0
- package/finetune/dataset/schema.py +73 -0
- package/finetune/dataset/score_data.py +115 -0
- package/finetune/dataset/validate_schema.py +104 -0
- package/finetune/eval.py +196 -0
- package/finetune/evals/queries.txt +56 -0
- package/finetune/gepa/__init__.py +1 -0
- package/finetune/gepa/best_prompt.txt +31 -0
- package/finetune/gepa/best_prompt_glm.txt +1 -0
- package/finetune/gepa/dspy_gepa.py +204 -0
- package/finetune/gepa/example.py +117 -0
- package/finetune/gepa/generate.py +129 -0
- package/finetune/gepa/gepa_outputs.jsonl +10 -0
- package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
- package/finetune/gepa/model.json +19 -0
- package/finetune/gepa/optimizer.py +70 -0
- package/finetune/gepa/score.py +84 -0
- package/finetune/jobs/eval.py +490 -0
- package/finetune/jobs/eval_common.py +354 -0
- package/finetune/jobs/eval_verbose.py +113 -0
- package/finetune/jobs/grpo.py +141 -0
- package/finetune/jobs/quantize.py +244 -0
- package/finetune/jobs/sft.py +121 -0
- package/finetune/pyproject.toml +23 -0
- package/finetune/reward.py +610 -0
- package/finetune/train.py +611 -0
- package/finetune/uv.lock +4070 -0
- package/flake.lock +61 -0
- package/flake.nix +83 -0
- package/migrate-schema.ts +162 -0
- package/package.json +56 -0
- package/skills/qmdr/SKILL.md +172 -0
- package/skills/qmdr/references/mcp-setup.md +88 -0
- package/src/app/commands/collection.ts +55 -0
- package/src/app/commands/context.ts +82 -0
- package/src/app/commands/document.ts +46 -0
- package/src/app/commands/maintenance.ts +60 -0
- package/src/app/commands/search.ts +45 -0
- package/src/app/ports/llm.ts +13 -0
- package/src/app/services/llm-service.ts +145 -0
- package/src/cli.test.ts +963 -0
- package/src/collections.ts +390 -0
- package/src/eval.test.ts +412 -0
- package/src/formatter.ts +427 -0
- package/src/llm.test.ts +559 -0
- package/src/llm.ts +1990 -0
- package/src/mcp.test.ts +889 -0
- package/src/mcp.ts +626 -0
- package/src/qmd.ts +3330 -0
- package/src/store/collections.ts +7 -0
- package/src/store/context.ts +10 -0
- package/src/store/db.ts +5 -0
- package/src/store/documents.ts +26 -0
- package/src/store/maintenance.ts +15 -0
- package/src/store/path.ts +13 -0
- package/src/store/search.ts +10 -0
- package/src/store-paths.test.ts +395 -0
- package/src/store.test.ts +2483 -0
- package/src/store.ts +2813 -0
- package/test/eval-harness.ts +223 -0
- package/tsconfig.json +29 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Common evaluation and reward scoring for QMD query expansion models.
|
|
3
|
+
|
|
4
|
+
Shared by sft.py and grpo.py for post-training evaluation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import csv
|
|
8
|
+
import io
|
|
9
|
+
import re
|
|
10
|
+
from collections import Counter
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from huggingface_hub import HfApi
|
|
14
|
+
|
|
15
|
+
# =============================================================================
# Reward function (single source of truth)
# =============================================================================

# Function words ignored by word_repetition_penalty when counting repeats.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Question/filler words excluded when extracting the "key terms" of a query
# (see get_key_terms and extract_named_entities).
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Boilerplate phrases that make a lex expansion useless for keyword search;
# lex_is_generic penalizes lines that are just these plus little else.
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Chat-template control tokens; any leakage into the output zeroes the score.
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def parse_expansion(text):
    """Split model output into lex/vec/hyde buckets; unprefixed lines are invalid."""
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefixes = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for prefix, key in prefixes:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def clean_model_output(text):
    """Strip the im_end token and any <think>...</think> block.

    Returns (cleaned_text, used_thinking).
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    has_thinking = '<think>' in cleaned and '</think>' in cleaned
    if not has_thinking:
        return cleaned, False
    cleaned = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL).strip()
    return cleaned, True
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def extract_named_entities(query):
    """Heuristically collect named entities from *query*, lowercased.

    Flags a word as an entity when it is an all-caps acronym, a capitalized
    mid-sentence word, a token containing tech symbols (.+-#@), a camel-case
    word with a leading capital, or a continuation of the previous entity
    (so multi-word names like "TDS motorsports" are captured whole).
    """
    entities = set()
    words = query.split()
    prev_was_entity = False
    for i, word in enumerate(words):
        # Drop surrounding punctuation before classifying.
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            # Pure punctuation breaks any multi-word entity chain.
            prev_was_entity = False
            continue
        is_entity = False
        # Acronym: fully uppercase, at least two chars (e.g. "AWS", "TDS").
        if clean.isupper() and len(clean) >= 2:
            entities.add(clean.lower()); is_entity = True
        # Capitalized word mid-sentence; i > 0 skips the sentence-initial
        # capital, which is usually not a name.
        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower()); is_entity = True
        # Symbol-bearing tech token (e.g. "node.js", "C++", "@scope").
        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
            entities.add(clean.lower()); is_entity = True
        # Leading-capital camel case (e.g. "JavaScript"); lowercase-first
        # camel case like "useEffect" is intentionally NOT matched here.
        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
            entities.add(clean.lower()); is_entity = True
        # Continuation of a multi-word entity (e.g. "motorsports" after "TDS").
        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower()); is_entity = True
        prev_was_entity = is_entity
    return entities
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_key_terms(query):
    """Lowercased content words of *query*, minus generic question words."""
    words = set(query.lower().split())
    return words.difference(KEY_TERM_STOPWORDS)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def lex_preserves_key_terms(lex_line, query):
    """True when the lex expansion keeps at least one key term of the query."""
    key_terms = get_key_terms(query)
    if not key_terms:
        # Query has no content words; nothing to preserve.
        return True
    lex_words = set(lex_line.lower().split())
    return bool(key_terms.intersection(lex_words))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def lex_preserves_entities(line, entities):
    """True if *line* mentions at least one entity (vacuously true when none)."""
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def lex_is_generic(lex_line):
    """True when the line is mostly filler ('search for', 'what is', ...).

    A line counts as generic if, after removing one occurrence of each word
    of a matching filler phrase, fewer than 3 characters of content remain.
    """
    normalized = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        phrase_words = phrase.split()
        if phrase not in normalized and not normalized.startswith(phrase_words[0]):
            continue
        leftover = normalized
        for word in phrase_words:
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def word_set_distance(a, b):
    """Number of words that appear in exactly one of the two strings."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len(words_a.symmetric_difference(words_b))
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def is_diverse(a, b, min_distance=2):
    """True when the two expansions differ enough to both be worth keeping."""
    left = a.lower().strip()
    right = b.lower().strip()
    if left == right:
        return False
    # One being a substring of the other is near-duplication.
    if left in right or right in left:
        return False
    return word_set_distance(left, right) >= min_distance
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def echoes_query(expansion, query):
    """True when the expansion merely restates the query near-verbatim."""
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if exp == q:
        return True
    # Containing the query with fewer than 10 extra characters is an echo.
    return q in exp and len(exp) < len(q) + 10
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def word_repetition_penalty(text):
    """Penalty of 2 per repeat beyond two for any non-stopword used 3+ times."""
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in counts.items():
        if count < 3 or word in STOPWORDS or len(word) <= 2:
            continue
        penalty += (count - 2) * 2
    return penalty
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def score_expansion(query, expansion):
    """Score expansion as float in [0.0, 1.0] for RL reward."""
    text, used_thinking = clean_model_output(expansion.strip())

    # Hard failures: leaked template tokens or any unprefixed line zero the reward.
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return 0.0
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return 0.0

    parsed = parse_expansion(text)

    # Format: base 10 plus 10 each for producing lex and vec lines (max 30).
    format_score = 10
    if parsed["lex"]: format_score += 10
    if parsed["vec"]: format_score += 10

    # Diversity (max 30): both types present, >=2 expansions, pairwise
    # distinctness within lex and vec, and no echoes of the raw query.
    diversity_score = 0
    if sum(1 for t in ("lex", "vec") if parsed[t]) >= 2: diversity_score += 10
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2): lex_div -= 2
    diversity_score += max(0, lex_div)
    # vec lines need a larger distance (3) since they are longer sentences.
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3): vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for exp in parsed["lex"] + parsed["vec"]:
        if echoes_query(exp, query): echo -= 3
    diversity_score += max(0, echo)

    # Hyde (max 20): only the first hyde line is judged — presence, length
    # in the 50-200 char sweet spot, single line, low word repetition.
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        if 50 <= len(hyde_text) <= 200: hyde_score += 5
        elif len(hyde_text) < 50: hyde_score += 2
        if "\n" not in hyde_text: hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # Quality (max 20): lex should be terser than vec, vec should read like
    # natural language, and lex should keep the query's key terms.
    quality_score = 5
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec: quality_score += 5
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]): quality_score += 5
        elif with_terms > 0: quality_score += 2

    # Entities (max 20, can go negative): dropping every entity from lex is
    # penalized hard (-30), generic filler lines cost 15 each.
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]): entity_score += 15
        elif with_entities > 0: entity_score += 5
        else: entity_score -= 30
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count: entity_score -= generic_count * 15
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0: entity_score += 5
    elif not entities:
        # No entities in the query: award the full credit unconditionally.
        entity_score = 10

    # Answering directly (no <think> block) earns a 20-point bonus.
    think_bonus = 0 if used_thinking else 20
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    # Component maxima: 30+30+20+20+20+20 = 140 with hyde, 120 without.
    max_possible = 140 if parsed["hyde"] else 120
    return max(0.0, min(1.0, total / max_possible))
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def extract_query_from_prompt(prompt):
    """Extract the search query from a formatted prompt string."""
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    tail = prompt.split(marker)[-1].strip()
    # Trim anything after the end-of-turn token left by the chat template.
    if "<|im_end|>" in tail:
        tail = tail.split("<|im_end|>")[0].strip()
    return tail
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class QMDRewardFunction:
    """Reward function wrapper for TRL's GRPOTrainer."""
    # TRL reads __name__ off reward callables for logging.
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions, prompts=None, **kwargs):
        prompts = prompts or []
        scores = []
        for idx, completion in enumerate(completions):
            # Recover the query from the matching prompt when available.
            query = extract_query_from_prompt(prompts[idx]) if idx < len(prompts) else ""
            scores.append(score_expansion(query, completion))
        return scores
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# =============================================================================
|
|
243
|
+
# Evaluation
|
|
244
|
+
# =============================================================================
|
|
245
|
+
|
|
246
|
+
# Held-out queries used for post-training evaluation, grouped by the content
# categories the expansion model is expected to handle.
EVAL_QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous
    "auth", "config", "setup", "api",
    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Generate a query expansion using the model."""
    # /no_think asks Qwen3 to answer directly without a reasoning block.
    user_turn = {"role": "user", "content": f"/no_think Expand this search query: {query}"}
    prompt = tokenizer.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant turn; fall back to stripping the prompt prefix.
    for marker in ("\nassistant\n", "assistant\n"):
        if marker in full_output:
            return full_output.split(marker)[-1].strip()
    return full_output[len(prompt):].strip()
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
    """Evaluate model on EVAL_QUERIES, print results, upload CSV.

    Args:
        model: causal LM used for generation.
        tokenizer: tokenizer matching *model*.
        label: tag used in printed output and in the uploaded CSV filename.
        upload_repo: HF repo id that receives the eval_<label>.csv artifact.
    """
    api = HfApi()
    # exist_ok: repeated eval runs reuse the same results repo.
    api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)

    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(EVAL_QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        score = score_expansion(query, expansion)
        pct = round(score * 100, 1)
        rating = ("Excellent" if pct >= 80 else "Good" if pct >= 60
                  else "Acceptable" if pct >= 40 else "Poor" if pct >= 20 else "Failed")
        marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(EVAL_QUERIES)} {pct:5.1f}% {rating:10s} {query}")
        results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})

    avg = sum(r["score"] for r in results) / len(results)
    ratings = Counter(r["rating"] for r in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    for r in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        c = ratings.get(r, 0)
        if c:
            print(f" {r:10s}: {c:2d} {'█' * c}")

    worst = sorted(results, key=lambda r: r["score"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['score']:5.1f}% {r['query']}")

    # Build the CSV fully in memory; upload_file accepts raw bytes directly.
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
    for r in results:
        writer.writerow([label, r["query"], r["expansion"], r["score"], r["rating"]])

    filename = f"eval_{label}.csv"
    # Fix: both messages previously printed the literal "(unknown)" where the
    # actual CSV filename belongs.
    print(f"\n Uploading {filename} to {upload_repo}...")
    api.upload_file(
        path_or_fileobj=buf.getvalue().encode("utf-8"),
        path_in_repo=filename,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f" Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = [
|
|
4
|
+
# "transformers>=4.45.0",
|
|
5
|
+
# "peft>=0.7.0",
|
|
6
|
+
# "torch",
|
|
7
|
+
# "huggingface_hub>=0.20.0",
|
|
8
|
+
# "accelerate",
|
|
9
|
+
# ]
|
|
10
|
+
# ///
|
|
11
|
+
"""
|
|
12
|
+
Verbose eval: prints the actual expansions for every query.
|
|
13
|
+
|
|
14
|
+
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval_verbose.py
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
import sys
|
|
20
|
+
from collections import Counter
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from huggingface_hub import login
|
|
24
|
+
from peft import PeftModel
|
|
25
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
26
|
+
|
|
27
|
+
BASE_MODEL = "Qwen/Qwen3-1.7B"  # frozen base checkpoint
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"  # SFT LoRA adapter (merged first)
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"  # GRPO LoRA adapter (stacked on top)

# Same query mix as EVAL_QUERIES in eval_common.py: technical docs,
# short/ambiguous, named entities, notes, research, debugging, temporal, complex.
QUERIES = [
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    "auth",
    "config",
    "setup",
    "api",
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def load_model(base, sft=None, grpo=None):
    """Load base model + tokenizer, optionally stacking SFT (merged) and GRPO adapters."""
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base, torch_dtype=torch.bfloat16, device_map="auto"
    )
    if sft:
        # Fold SFT weights into the base so the GRPO adapter sits on top.
        model = PeftModel.from_pretrained(model, sft).merge_and_unload()
    if grpo:
        model = PeftModel.from_pretrained(model, grpo)
    model.eval()
    return model, tokenizer
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def generate(model, tokenizer, query):
    """Generate one expansion for *query*, stripping template and thinking artifacts."""
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Keep only the assistant turn (first matching marker wins).
    for marker in ("\nassistant\n", "assistant\n"):
        if marker in text:
            text = text.split(marker)[-1].strip()
            break
    if "<think>" in text:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    return text
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def main():
    """Load the adapter-stacked model and print expansions for every eval query."""
    # HF Jobs injects the token via secret; local runs may already be logged in.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # stderr keeps status noise out of the expansion output stream.
    print("Loading GRPO model...", file=sys.stderr)
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)

    # Print every expansion verbatim for manual inspection.
    for i, query in enumerate(QUERIES, 1):
        expansion = generate(model, tokenizer, query)
        print(f"\n{'='*60}")
        print(f"[{i}/{len(QUERIES)}] {query}")
        print(f"{'─'*60}")
        print(expansion)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
if __name__ == "__main__":
|
|
113
|
+
main()
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = [
|
|
4
|
+
# "trl>=0.12.0",
|
|
5
|
+
# "peft>=0.7.0",
|
|
6
|
+
# "transformers>=4.45.0",
|
|
7
|
+
# "accelerate>=0.24.0",
|
|
8
|
+
# "huggingface_hub>=0.20.0",
|
|
9
|
+
# "datasets",
|
|
10
|
+
# "bitsandbytes",
|
|
11
|
+
# "torch",
|
|
12
|
+
# ]
|
|
13
|
+
# ///
|
|
14
|
+
"""
|
|
15
|
+
GRPO training for QMD query expansion (Qwen3-1.7B).
|
|
16
|
+
|
|
17
|
+
Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
|
|
18
|
+
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import os
|
|
22
|
+
import sys
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
from datasets import load_dataset
|
|
26
|
+
from huggingface_hub import login
|
|
27
|
+
from peft import LoraConfig, PeftModel, get_peft_model
|
|
28
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
29
|
+
from trl import GRPOTrainer, GRPOConfig
|
|
30
|
+
|
|
31
|
+
# Download eval_common.py if running as a standalone script (e.g. HF Jobs),
# where only this single file is shipped to the worker.
_eval_common_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "eval_common.py")
if not os.path.exists(_eval_common_path):
    import urllib.request
    _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
    _opener = urllib.request.build_opener()
    _token = os.environ.get("HF_TOKEN", "")
    if _token:
        # Authenticate the raw-file download (dataset may be private).
        _opener.addheaders = [("Authorization", f"Bearer {_token}")]
    with open(_eval_common_path, "wb") as _f:
        _f.write(_opener.open(_url).read())
# Make the (possibly just-downloaded) module importable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from eval_common import QMDRewardFunction, run_eval
|
|
44
|
+
|
|
45
|
+
# --- Config (inlined from configs/grpo.yaml) ---
BASE_MODEL = "Qwen/Qwen3-1.7B"  # frozen base checkpoint
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"  # adapter merged before GRPO
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"  # hub destination
DATASET = "tobil/qmd-query-expansion-train"  # prompt source for sampling
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def main():
    """Run GRPO post-training on top of the merged SFT checkpoint, then evaluate."""
    # HF Jobs injects the token via secret; local runs may already be logged in.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def extract_prompt(example):
        # GRPO needs only the user prompt; completions are sampled online.
        content = example["messages"][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return {"prompt": formatted}

    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    # Fixed seed keeps the <=1000-prompt subset reproducible across runs.
    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    # Load base model, merge SFT adapter
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
    )
    print(f"Merging SFT adapter {SFT_MODEL}...")
    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Fresh LoRA for GRPO (small: rank 4, q/v only)
    grpo_lora = LoraConfig(
        r=4, lora_alpha=8, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, grpo_lora)
    model.print_trainable_parameters()

    config = GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,

        num_generations=4,
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint

        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,

        logging_steps=10,
        save_strategy="epoch",
        bf16=True,

        report_to="none",
    )

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # --- Automatic evaluation ---
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
if __name__ == "__main__":
|
|
141
|
+
main()
|